LLMs/3. Chroma Implementation.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "20da3ce4-6291-40de-8068-e66beb639137",
"metadata": {},
"outputs": [],
"source": [
"# CHROMOA\n",
"from langchain.vectorstores import Chroma\n",
"from langchain.embeddings import SentenceTransformerEmbeddings\n",
"from langchain.schema import Document\n",
"from ollama import chat\n",
"import os\n",
"import re\n",
"\n",
"DOCUMENT_PATHS = [\n",
" r'/home/masih/rag_data/Hamrah.txt', #replace path\n",
" r'/home/masih/rag_data/vape.txt',\n",
" r'/home/masih/rag_data/Shah.txt',\n",
" r'/home/masih/rag_data/Khalife.txt',\n",
" r'/home/masih/rag_data/carbon.txt',\n",
" r'/home/masih/rag_data/takapoo.txt'\n",
"]\n",
"\n",
"EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'\n",
"LLM_MODEL = 'gemma2:9b'\n",
"CHUNK_SIZE = 1000\n",
"OVERLAP = 200\n",
"CHROMA_PERSIST_DIR = r'\\home\\Masih\\chroma_db\\chroma_db' \n",
"\n",
"class ChromaRAGSystem:\n",
" def __init__(self):\n",
" # Init embedding model\n",
" self.embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)\n",
" # Vector store instance\n",
" self.vector_db = None\n",
" \n",
" def build_vector_store(self):\n",
" \"\"\"Process documents and create Chroma vector store\"\"\"\n",
" all_docs = []\n",
" \n",
"\n",
" for doc_idx, path in enumerate(DOCUMENT_PATHS):\n",
" with open(path, 'r', encoding='utf-8') as f:\n",
" text = re.sub(r'\\s+', ' ', f.read()).strip()\n",
" # sliding window chunking\n",
" chunks = [\n",
" text[i:i+CHUNK_SIZE] \n",
" for i in range(0, len(text), CHUNK_SIZE - OVERLAP)\n",
" ]\n",
" # LangChain documents with metadata\n",
" for chunk in chunks:\n",
" all_docs.append(Document(\n",
" page_content=chunk,\n",
" metadata={\"source_doc\": doc_idx}\n",
" ))\n",
" \n",
" # Chroma vector store\n",
" self.vector_db = Chroma.from_documents(\n",
" documents=all_docs,\n",
" embedding=self.embeddings,\n",
" persist_directory=CHROMA_PERSIST_DIR\n",
" )\n",
" self.vector_db.persist()\n",
" \n",
" def load_vector_store(self):\n",
" \"\"\"Load existing Chroma vector store\"\"\"\n",
" self.vector_db = Chroma(\n",
" persist_directory=CHROMA_PERSIST_DIR,\n",
" embedding_function=self.embeddings\n",
" )\n",
" \n",
" def document_query(self, query, doc_index, top_k=5):\n",
" \"\"\"Retrieve context from specific document\"\"\"\n",
" # Chroma metadata filtering\n",
" results = self.vector_db.similarity_search(\n",
" query=query,\n",
" k=top_k,\n",
" filter={\"source_doc\": doc_index}\n",
" )\n",
" return [doc.page_content for doc in results]\n",
"\n",
"class AnswerGenerator:\n",
" def __init__(self, rag_system):\n",
" self.rag = rag_system\n",
" \n",
" def generate_response(self, question, doc_index):\n",
" \"\"\"Generate context-aware answer using LLM\"\"\"\n",
" # Retrieve relevant context\n",
" context_chunks = self.rag.document_query(question, doc_index)\n",
" context = \"\\n\".join(context_chunks)\n",
" \n",
" prompt = f\"\"\"با استفاده از متن زیر به سوال پاسخ دهید:\n",
"{context}\n",
"\n",
"اگر پاسخ در متن وجود ندارد عبارت 'پاسخی یافت نشد' را برگردانید\n",
"\n",
"سوال: {question}\n",
"پاسخ:\"\"\"\n",
" \n",
" response = chat(model=LLM_MODEL, messages=[{'role': 'user', 'content': prompt}])\n",
" return response['message']['content']\n",
"\n",
"if __name__ == \"__main__\":\n",
" rag_system = ChromaRAGSystem()\n",
" \n",
" # Init vector store\n",
" if not os.path.exists(CHROMA_PERSIST_DIR):\n",
" print(\"Creating new vector store...\")\n",
" rag_system.build_vector_store()\n",
" else:\n",
" print(\"Loading existing vector store...\")\n",
" rag_system.load_vector_store()\n",
" \n",
" # Init answer generator\n",
" answer_engine = AnswerGenerator(rag_system)\n",
"\n",
" queries = [\n",
" (\"چرا اینترنت همراه اول گوشی وصل نمیشود؟\", 0),\n",
" (\"چطوری ویپ مورد نظرمو پیدا کنم؟\", 1),\n",
" (\"شاه عباس که بود؟\", 2),\n",
" (\"خلیفه سلطان که بود و چه کرد؟\", 3),\n",
" (\"کربن اکتیو و کربن بلک چه هستند و چه تفاوتی دارند و برای چه استفاده میشن؟\", 4),\n",
" (\"شرکت تکاپو صنعت نامی چه محصولاتی ارایه میدهد؟ چه چیزی این شرکت را منحصر به فرد میسازد؟ سهام این شرکت صعودی است یا نزولی؟\", 5)\n",
" ]\n",
" \n",
" with open( r'/home/masih/rag_data/response.txt', 'w', encoding='utf-8') as output_file: #repalce path\n",
" for q_num, (query, doc_idx) in enumerate(queries):\n",
" answer = answer_engine.generate_response(query, doc_idx)\n",
" output_file.write(f\"سوال {q_num+1} ({doc_idx+1}):\\n{query}\\n\\nپاسخ:\\n{answer}\\n\\n{'='*50}\\n\\n\")\n",
" print(f\"پردازش سوال {q_num+1}/{len(queries)} تکمیل شد\")\n",
"\n",
"print(\"تمامی سوالات با موفقیت پردازش شدند!\")"
]
}
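,
{
"cell_type": "code",
"execution_count": null,
"id": "usage-example-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Minimal usage sketch: load the persisted store built by the cell above and\n",
"# answer one of the sample questions (assumes that cell has already been run\n",
"# and CHROMA_PERSIST_DIR exists on disk).\n",
"rag = ChromaRAGSystem()\n",
"rag.load_vector_store()\n",
"engine = AnswerGenerator(rag)\n",
"# doc_index=2 restricts retrieval to Shah.txt\n",
"print(engine.generate_response('شاه عباس که بود؟', 2))"
]
}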
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}