LLMs/2. Advanced RAG Integration.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "636bba8f-4de0-434f-9064-818d96f628bf",
"metadata": {},
"outputs": [],
"source": [
"# ADVANCED RAG INTEGRATION\n",
"from ollama import chat\n",
"import numpy as np\n",
"import faiss\n",
"from sentence_transformers import SentenceTransformer\n",
"import os\n",
"import re\n",
"\n",
"DOCUMENT_PATHS = [\n",
" r'C:\\Users\\ASUS\\Downloads\\Hamrah.txt', #replace path\n",
" r'C:\\Users\\ASUS\\Downloads\\vape.txt',\n",
" r'C:\\Users\\ASUS\\Downloads\\Shah.txt',\n",
" r'C:\\Users\\ASUS\\Downloads\\Khalife.txt',\n",
" r'C:\\Users\\ASUS\\Downloads\\carbon.txt',\n",
" r'C:\\Users\\ASUS\\Downloads\\takapoo.txt',\n",
" r'C:\\Users\\ASUS\\Downloads\\mahmood.txt'\n",
"]\n",
"\n",
"EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'\n",
"LLM_MODEL = 'llama3.2'\n",
"CHUNK_SIZE = 1000\n",
"OVERLAP = 200\n",
"INDEX_PATH = r'C:\\Users\\ASUS\\Downloads\\doc_index.faiss'\n",
"CHUNK_MAP_PATH = r'C:\\Users\\ASUS\\Downloads\\chunk_map.npy'\n",
"\n",
"class AdvancedRAG:\n",
" def __init__(self):\n",
" self.encoder = SentenceTransformer(EMBEDDING_MODEL)\n",
" self.index = None\n",
" self.chunk_map = []\n",
" \n",
" def create_index(self):\n",
" \"\"\"Create FAISS index with cosine similarity and document mapping\"\"\"\n",
" all_chunks = []\n",
" doc_mapping = []\n",
" \n",
" # Process via CHUNKING (REQ 4 RAG)\n",
" for doc_idx, path in enumerate(DOCUMENT_PATHS):\n",
" with open(path, 'r', encoding='utf-8') as f:\n",
" text = re.sub(r'\\s+', ' ', f.read()).strip()\n",
" chunks = [text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE - OVERLAP)]\n",
" all_chunks.extend(chunks)\n",
" doc_mapping.extend([doc_idx] * len(chunks))\n",
" \n",
" # Normalized embeddings (REQ 4 cosine similarity)\n",
" embeddings = self.encoder.encode(all_chunks)\n",
" faiss.normalize_L2(embeddings) \n",
" \n",
" # FAISS index & Mapping\n",
" self.index = faiss.IndexFlatIP(embeddings.shape[1])\n",
" self.index.add(embeddings.astype(np.float32))\n",
" self.chunk_map = np.array(doc_mapping)\n",
" \n",
" # Index \n",
" faiss.write_index(self.index, INDEX_PATH)\n",
" # Mapping \n",
" np.save(CHUNK_MAP_PATH, self.chunk_map)\n",
" \n",
" def load_index(self):\n",
" \"\"\"LOAD EXISTING DATA\"\"\"\n",
" self.index = faiss.read_index(INDEX_PATH)\n",
" self.chunk_map = np.load(CHUNK_MAP_PATH, allow_pickle=True)\n",
" \n",
" def query(self, question, doc_index, top_k=6):\n",
" \"\"\"DOCUMENT-SPECIFIC QUERY WITH COSINE SIMILARITY \"\"\"\n",
" # Encode \n",
" query_embed = self.encoder.encode([question])\n",
" # Normalize \n",
" faiss.normalize_L2(query_embed)\n",
" \n",
" distances, indices = self.index.search(query_embed.astype(np.float32), top_k*3)\n",
" \n",
" relevant_chunks = []\n",
" for idx in indices[0]:\n",
" if self.chunk_map[idx] == doc_index:\n",
" relevant_chunks.append(idx)\n",
" if len(relevant_chunks) >= top_k:\n",
" break\n",
" \n",
" return relevant_chunks\n",
"\n",
"class AnswerGenerator:\n",
" def __init__(self, rag_system):\n",
" self.rag = rag_system\n",
" self.chunks = [] \n",
" \n",
" def get_answer(self, question, doc_index):\n",
" \"\"\"GENERATING CONTEXT-AWARE ANSWER\"\"\"\n",
" if not self.chunks:\n",
" self._load_chunks()\n",
" \n",
" chunk_indices = self.rag.query(question, doc_index)\n",
" context = \"\\n\".join([self.chunks[idx] for idx in chunk_indices])\n",
" \n",
" prompt = f\"\"\"با استفاده از متن زیر به سوال پاسخ دهید:\n",
"{context}\n",
"\n",
"اگر پاسخ در متن وجود ندارد عبارت 'پاسخی یافت نشد' را برگردانید\n",
"\n",
"سوال: {question}\n",
"پاسخ:\"\"\"\n",
" \n",
" response = chat(model=LLM_MODEL, messages=[{'role': 'user', 'content': prompt}])\n",
" return response['message']['content']\n",
" \n",
" def _load_chunks(self):\n",
" \"\"\"LOAD ALL CHUNKS(LAZY)\"\"\"\n",
" self.chunks = []\n",
" for path in DOCUMENT_PATHS:\n",
" with open(path, 'r', encoding='utf-8') as f:\n",
" text = re.sub(r'\\s+', ' ', f.read()).strip()\n",
" self.chunks.extend([text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE - OVERLAP)])\n",
"\n",
"# MAIN EXE of RAG\n",
"if __name__ == \"__main__\":\n",
" # RAG init\n",
" rag = AdvancedRAG()\n",
" \n",
" if not os.path.exists(INDEX_PATH):\n",
" print(\"Building optimized index...\")\n",
" rag.create_index()\n",
" else:\n",
" print(\"Loading existing index...\")\n",
" rag.load_index()\n",
" # Answer Generator init\n",
" generator = AnswerGenerator(rag)\n",
" \n",
" queries = [\n",
" (\"چرا اینترنت همراه اول گوشی وصل نمیشود؟\", 0),\n",
" (\"چطوری ویپ مورد نظرمو پیدا کنم؟\", 1),\n",
" (\"شاه عباس که بود؟\", 2),\n",
" (\"خلیفه سلطان که بود و چه کرد؟\", 3),\n",
" (\"کربن اکتیو و کربن بلک چه هستند و چه تفاوتی دارند و برای چه استفاده میشن؟\", 4),\n",
" (\"شرکت تکاپو صنعت نامی چه محصولاتی ارایه میدهد؟ چه چیزی این شرکت را منحصر به فرد میسازد؟ سهام این شرکت صعودی است یا نزولی؟\", 5),\n",
" (\"6 ,\"سید محمود خلیفه سلطانی کیست؟\"),\n",
" ]\n",
" \n",
" with open(r'C:\\Users\\ASUS\\Downloads\\representation.txt', 'w', encoding='utf-8') as f: #replace path\n",
" for q_idx, (query, doc_idx) in enumerate(queries):\n",
" answer = generator.get_answer(query, doc_idx)\n",
" f.write(f\"سوال {q_idx+1} ({doc_idx+1}):\\n{query}\\n\\nپاسخ:\\n{answer}\\n\\n{'='*50}\\n\\n\")\n",
" print(f\"پردازش سوال {q_idx+1}/{len(queries)} تکمیل شد\")\n",
"\n",
"print(\"تمامی سوالات با موفقیت پردازش شدند!\")"
]
}
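,
{
"cell_type": "markdown",
"id": "cosine-check-note",
"metadata": {},
"source": [
"A quick sanity check, not part of the original pipeline: because the stored embeddings are L2-normalized, the scores returned by `IndexFlatIP` are cosine similarities. The toy sketch below reuses the `rag` encoder from the cell above; the two test strings are made-up placeholders, and the FAISS score is compared against a cosine similarity computed directly with NumPy."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cosine-check",
"metadata": {},
"outputs": [],
"source": [
"# Toy verification that inner product over L2-normalized vectors == cosine similarity\n",
"# (illustrative only; the strings below are placeholders, not project data)\n",
"texts = [\"a short test sentence\", \"another short test sentence\"]\n",
"vecs = rag.encoder.encode(texts).astype(np.float32)\n",
"faiss.normalize_L2(vecs)\n",
"\n",
"toy_index = faiss.IndexFlatIP(vecs.shape[1])\n",
"toy_index.add(vecs)\n",
"\n",
"# Search with the first vector: the top hit is itself with score ~1.0\n",
"scores, ids = toy_index.search(vecs[:1], 2)\n",
"manual_cos = float(vecs[0] @ vecs[1])  # unit vectors: dot product = cosine\n",
"print(scores[0], ids[0], manual_cos)\n",
"assert np.isclose(scores[0][1], manual_cos, atol=1e-4)"
]
},
{
"cell_type": "markdown",
"id": "single-query-note",
"metadata": {},
"source": [
"A minimal single-question usage sketch, assuming the main cell has already run so that the index is built (or loaded) and `generator` exists. The question is one of the Persian queries above, aimed at Shah.txt (document index 2)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "single-query",
"metadata": {},
"outputs": [],
"source": [
"# One-off query against document 2 (Shah.txt); mirrors one iteration of the main loop\n",
"question = \"شاه عباس که بود؟\"  # \"Who was Shah Abbas?\"\n",
"answer = generator.get_answer(question, doc_index=2)\n",
"print(answer)"
]
}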
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}