From 4c9da22a3965ae57a622a0a36806ff9b3df398c1 Mon Sep 17 00:00:00 2001 From: MasihMoafi Date: Fri, 2 May 2025 11:12:23 +0000 Subject: [PATCH] Upload files to "/" --- Multimodal.py | 267 ++++++++++++++++ enhanced_combined.py | 746 +++++++++++++++++++++++++------------------ hybrid.py | 204 ++++++++++++ memory.py | 198 ++++++++++++ req.txt | 1 + 5 files changed, 1103 insertions(+), 313 deletions(-) create mode 100644 Multimodal.py create mode 100644 hybrid.py create mode 100644 memory.py create mode 100644 req.txt diff --git a/Multimodal.py b/Multimodal.py new file mode 100644 index 0000000..162a161 --- /dev/null +++ b/Multimodal.py @@ -0,0 +1,267 @@ +import os +import subprocess + +# Clear proxy settings +def clear_proxy_settings(): + for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]: + if var in os.environ: + del os.environ[var] + +clear_proxy_settings() + +import os +import tempfile +import subprocess +from datetime import datetime + +import streamlit as st +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.vectorstores import InMemoryVectorStore +from langchain_ollama import OllamaEmbeddings +from langchain_ollama.llms import OllamaLLM +from langchain_text_splitters import RecursiveCharacterTextSplitter +from unstructured.partition.pdf import partition_pdf +from unstructured.partition.utils.constants import PartitionStrategy +from search_utils import duckduckgo_search, rank_results + +template = """ +تو یک دستیار هستی که از یک داده های متنی و تصویری استفاده میکنی تا به سوالات کاربر به زبان فارسی سلیس پاسخ بدی. +Question: {question} +Context: {context} +Answer: +""" + +pdfs_directory = 'multi-modal-rag/pdfs/' +figures_directory = 'multi-modal-rag/figures/' +images_directory = 'multi-modal-rag/images/' +videos_directory = 'multi-modal-rag/videos/' +audio_directory = 'multi-modal-rag/audio/' +frames_directory = 'multi-modal-rag/frames/' + +# Create directories if they don't exist +os.makedirs(pdfs_directory, exist_ok=True) +os.makedirs(figures_directory, exist_ok=True) +os.makedirs(images_directory, exist_ok=True) +os.makedirs(videos_directory, exist_ok=True) +os.makedirs(audio_directory, exist_ok=True) +os.makedirs(frames_directory, exist_ok=True) + +embeddings = OllamaEmbeddings(model="llama3.2") +vector_store = InMemoryVectorStore(embeddings) + +model = OllamaLLM(model="gemma3") + +def upload_pdf(file): + with open(pdfs_directory + file.name, "wb") as f: + f.write(file.getbuffer()) + +def upload_image(file): + with open(images_directory + file.name, "wb") as f: + f.write(file.getbuffer()) + return images_directory + file.name + +def upload_video(file): + file_path = videos_directory + file.name + with open(file_path, "wb") as f: + f.write(file.getbuffer()) + return file_path + +def upload_audio(file): + file_path = audio_directory + file.name + with open(file_path, "wb") as f: + f.write(file.getbuffer()) + return file_path + +def load_pdf(file_path): + elements = partition_pdf( + file_path, + strategy=PartitionStrategy.HI_RES, + extract_image_block_types=["Image", "Table"], + extract_image_block_output_dir=figures_directory + ) + + text_elements = [element.text for element in elements if element.category not in ["Image", "Table"]] + + for file in os.listdir(figures_directory): + extracted_text = extract_text(figures_directory + file) + text_elements.append(extracted_text) + + return "\n\n".join(text_elements) + +def extract_frames(video_path, num_frames=5): + """Extract frames from video file and save them to frames 
directory""" + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + base_name = os.path.basename(video_path).split('.')[0] + frame_paths = [] + + # Extract frames using ffmpeg + for i in range(num_frames): + frame_path = f"{frames_directory}{base_name}_{timestamp}_{i}.jpg" + cmd = [ + 'ffmpeg', '-i', video_path, + '-ss', str(i * (1/num_frames)), '-vframes', '1', + '-q:v', '2', frame_path, '-y' + ] + try: + subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + frame_paths.append(frame_path) + except subprocess.CalledProcessError: + st.warning(f"Failed to extract frame {i} from video") + + return frame_paths + +def process_audio(audio_path): + """Process audio file using the model""" + audio_description = model.invoke( + f"Describe what you hear in this audio file: {os.path.basename(audio_path)}" + ) + return f"Audio file: {os.path.basename(audio_path)}. Description: {audio_description}" + +def extract_text(file_path): + model_with_image_context = model.bind(images=[file_path]) + return model_with_image_context.invoke("Tell me what do you see in this picture.") + +def split_text(text): + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200, + add_start_index=True + ) + + return text_splitter.split_text(text) + +def index_docs(texts): + vector_store.add_texts(texts) + +def retrieve_docs(query): + return vector_store.similarity_search(query) + +def answer_question(question, documents): + local_context = "\n\n".join([doc.page_content for doc in documents]) + prompt = ChatPromptTemplate.from_template(template) + chain = prompt | model + return chain.invoke({"question": question, "context": local_context}) + +# Sidebar for upload options +st.sidebar.title("Upload Documents") +upload_option = st.sidebar.radio("Choose upload type:", ["PDF", "Image", "Video", "Audio", "Search"]) + +if upload_option == "Search": + st.title("Web Search with BM25 Ranking") + search_query = st.text_input("Enter your search query:") + + if search_query: + with st.spinner("Searching and ranking results..."): + # Get search results + search_results = duckduckgo_search(search_query, max_results=10) + + if search_results: + # Rank results using BM25 + ranked_results = rank_results(search_query, search_results) + + # Display results + st.subheader("Ranked Search Results") + for i, result in enumerate(ranked_results): + with st.expander(f"{i+1}. 
{result.title}"): + st.write(f"**Snippet:** {result.snippet}") + st.write(f"**URL:** {result.url}") + + # Option to ask about search results + st.subheader("Ask about these results") + question = st.text_input("Enter your question about the search results:") + + if question: + # Prepare context from top results + context = "\n\n".join([f"Title: {r.title}\nSnippet: {r.snippet}" for r in ranked_results[:3]]) + + # Use the model to answer + prompt = ChatPromptTemplate.from_template(template) + chain = prompt | model + + with st.spinner("Generating answer..."): + response = chain.invoke({"question": question, "context": context}) + st.markdown("### Answer") + st.write(response.content) + else: + st.warning("No search results found") + +elif upload_option == "PDF": + uploaded_file = st.file_uploader( + "Upload PDF", + type="pdf", + accept_multiple_files=False + ) + + if uploaded_file: + upload_pdf(uploaded_file) + with st.spinner("Processing PDF..."): + text = load_pdf(pdfs_directory + uploaded_file.name) + chunked_texts = split_text(text) + index_docs(chunked_texts) + st.success("PDF processed successfully!") + +elif upload_option == "Image": + uploaded_image = st.file_uploader( + "Upload Image", + type=["jpg", "jpeg", "png"], + accept_multiple_files=False + ) + + if uploaded_image: + image_path = upload_image(uploaded_image) + st.image(image_path, caption="Uploaded Image", use_column_width=True) + with st.spinner("Processing image..."): + image_description = extract_text(image_path) + index_docs([image_description]) + st.success("Image processed and added to knowledge base") + +elif upload_option == "Video": + uploaded_video = st.file_uploader( + "Upload Video", + type=["mp4", "avi", "mov", "mkv"], + accept_multiple_files=False + ) + + if uploaded_video: + video_path = upload_video(uploaded_video) + st.video(video_path) + + with st.spinner("Processing video frames..."): + frame_paths = extract_frames(video_path) + video_descriptions = [] + + for frame_path in frame_paths: + st.image(frame_path, caption=f"Frame from video", width=200) + frame_description = extract_text(frame_path) + video_descriptions.append(frame_description) + + # Add a combined description + combined_description = f"Video file: {uploaded_video.name}. 
Content description: " + " ".join(video_descriptions) + index_docs([combined_description]) + st.success("Video processed and added to knowledge base") + +else: # Audio option + uploaded_audio = st.file_uploader( + "Upload Audio", + type=["mp3", "wav", "ogg"], + accept_multiple_files=False + ) + + if uploaded_audio: + audio_path = upload_audio(uploaded_audio) + st.audio(audio_path) + + with st.spinner("Processing audio..."): + # For audio, we'll use the model directly without visual context + audio_description = process_audio(audio_path) + index_docs([audio_description]) + st.success("Audio processed and added to knowledge base") + +# Chat interface +question = st.chat_input() + +if question: + st.chat_message("user").write(question) + related_documents = retrieve_docs(question) + answer = answer_question(question, related_documents) + st.chat_message("assistant").write(answer) \ No newline at end of file diff --git a/enhanced_combined.py b/enhanced_combined.py index fd82cfa..95b0ffd 100644 --- a/enhanced_combined.py +++ b/enhanced_combined.py @@ -1,340 +1,460 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + import os -import pickle +import re import json -import nltk +import ssl +import argparse import requests -import time from bs4 import BeautifulSoup from urllib.parse import quote -from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader -from langchain_text_splitters import RecursiveCharacterTextSplitter -from langchain_community.retrievers import BM25Retriever +from langchain_community.embeddings import HuggingFaceEmbeddings +from langchain_community.vectorstores import Chroma +from langchain_core.documents import Document +import traceback -try: - nltk.data.find('tokenizers/punkt') -except LookupError: - nltk.download('punkt') +# Disable SSL warnings and proxy settings +ssl._create_default_https_context = ssl._create_unverified_context +requests.packages.urllib3.disable_warnings() -class ModularRAG: +def clear_proxy_settings(): + """Remove proxy environment variables that might cause connection issues.""" + for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]: + if var in os.environ: + print(f"Removing proxy env var: {var}") + del os.environ[var] + +# Run at module load time +clear_proxy_settings() + +# Configuration +DOCUMENT_PATHS = [ + r'doc1.txt', + r'doc2.txt', + r'doc3.txt', + r'doc4.txt', + r'doc5.txt', + r'doc6.txt' +] +EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2' +LLM_MODEL = 'gemma3' +CHUNK_SIZE = 1000 +OVERLAP = 200 +CHROMA_PERSIST_DIR = 'chroma_db' + +# Confidence thresholds +THRESHOLDS = { + 'direct_answer': 0.7, + 'rag_confidence': 0.6, + 'web_search': 0.5 +} + +def query_llm(prompt, model='gemma3'): + """Query the LLM model directly using Ollama API.""" + try: + ollama_endpoint = "http://localhost:11434/api/generate" + payload = { + "model": model, + "prompt": prompt, + "stream": False + } + response = requests.post(ollama_endpoint, json=payload) + + if response.status_code == 200: + result = response.json() + return result.get('response', '') + else: + print(f"Ollama API error: {response.status_code}") + return f"Error calling Ollama API: {response.status_code}" + except Exception as e: + print(f"Error querying LLM: {e}") + return f"Error: {str(e)}" + +class BM25Retriever: + """BM25 retriever implementation for text similarity search""" + + @classmethod + def from_documents(cls, documents): + """Create a BM25 retriever from documents""" + retriever = cls() + 
retriever.documents = documents + retriever.k = 4 + return retriever + + def get_relevant_documents(self, query): + """Get relevant documents using BM25 algorithm""" + # Simple BM25-like implementation + scores = [] + query_terms = set(re.findall(r'\b\w+\b', query.lower())) + + for doc in self.documents: + doc_terms = set(re.findall(r'\b\w+\b', doc.page_content.lower())) + # Calculate term overlap as a simple approximation of BM25 + overlap = len(query_terms.intersection(doc_terms)) + scores.append((doc, overlap)) + + # Sort by score and return top k + sorted_docs = [doc for doc, score in sorted(scores, key=lambda x: x[1], reverse=True)] + return sorted_docs[:self.k] + +class HybridRetriever: + """Hybrid retriever combining BM25 and vector search with configurable weights""" + + def __init__(self, vector_retriever, bm25_retriever, vector_weight=0.3): + """Initialize with separate retrievers and weights""" + self._vector_retriever = vector_retriever + self._bm25_retriever = bm25_retriever + self._vector_weight = vector_weight + self._bm25_weight = 1.0 - vector_weight + + def get_relevant_documents(self, query): + """Get relevant documents using weighted combination of retrievers""" + try: + # Get results from both retrievers + vector_docs = self._vector_retriever.get_relevant_documents(query) + bm25_docs = self._bm25_retriever.get_relevant_documents(query) + + # Create dictionary to track unique documents and their scores + doc_dict = {} + + # Add vector docs with their weights + for i, doc in enumerate(vector_docs): + # Score based on position (inverse rank) + score = (len(vector_docs) - i) * self._vector_weight + doc_id = doc.page_content[:50] # Use first 50 chars as a simple ID + if doc_id in doc_dict: + doc_dict[doc_id]["score"] += score + else: + doc_dict[doc_id] = {"doc": doc, "score": score} + + # Add BM25 docs with their weights + for i, doc in enumerate(bm25_docs): + # Score based on position (inverse rank) + score = (len(bm25_docs) - i) * self._bm25_weight + doc_id = doc.page_content[:50] # Use first 50 chars as a simple ID + if doc_id in doc_dict: + doc_dict[doc_id]["score"] += score + else: + doc_dict[doc_id] = {"doc": doc, "score": score} + + # Sort by combined score (highest first) + sorted_docs = sorted(doc_dict.values(), key=lambda x: x["score"], reverse=True) + + # Return just the document objects + return [item["doc"] for item in sorted_docs] + except Exception as e: + print(f"Error in hybrid retrieval: {e}") + return [] + +class AgenticQASystem: + """QA system implementing the specified architecture""" + def __init__(self): - self.storage_path = "./rag_data" - - if not os.path.exists(self.storage_path): - os.makedirs(self.storage_path) - os.makedirs(os.path.join(self.storage_path, "documents")) - os.makedirs(os.path.join(self.storage_path, "web_results")) - - self.documents = [] - self.web_results = [] - - # Web crawler settings - self.headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - self.num_search_results = 10 - self.max_depth = 2 - self.max_links_per_page = 5 - self.max_paragraphs = 5 - - self._load_saved_data() + """Initialize the QA system with retrievers""" + # Load embeddings + self.embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL) + # Load documents and retrievers + self.documents = self.load_documents() + self.retriever = self.initialize_retriever() - def _load_saved_data(self): - doc_path = os.path.join(self.storage_path, "documents", "docs.pkl") - 
web_path = os.path.join(self.storage_path, "web_results", "web.json") - - if os.path.exists(doc_path): + def load_documents(self): + """Load documents from configured paths with sliding window chunking""" + print("Loading documents...") + docs = [] + for path in DOCUMENT_PATHS: try: - with open(doc_path, 'rb') as f: - self.documents = pickle.load(f) + with open(path, 'r', encoding='utf-8') as f: + text = re.sub(r'\s+', ' ', f.read()).strip() + # Sliding window chunking + chunks = [text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE - OVERLAP)] + for chunk in chunks: + docs.append(Document( + page_content=chunk, + metadata={"source": os.path.basename(path)} + )) except Exception as e: - print(f"خطا در بارگیری اسناد: {e}") - - if os.path.exists(web_path): - try: - with open(web_path, 'r', encoding='utf-8') as f: - self.web_results = json.load(f) - except Exception as e: - print(f"خطا در بارگیری نتایج وب: {e}") + print(f"Error loading document {path}: {e}") + print(f"Loaded {len(docs)} document chunks") + return docs - def _save_documents(self): - doc_path = os.path.join(self.storage_path, "documents", "docs.pkl") - try: - with open(doc_path, 'wb') as f: - pickle.dump(self.documents, f) - except Exception as e: - print(f"خطا در ذخیره‌سازی اسناد: {e}") - - def _save_web_results(self): - web_path = os.path.join(self.storage_path, "web_results", "web.json") - try: - with open(web_path, 'w', encoding='utf-8') as f: - json.dump(self.web_results, f, ensure_ascii=False, indent=2) - except Exception as e: - print(f"خطا در ذخیره‌سازی نتایج وب: {e}") - - def load_pdf(self, file_path): - if not os.path.exists(file_path): - raise FileNotFoundError(f"فایل یافت نشد: {file_path}") - - try: - loader = PDFPlumberLoader(file_path) - documents = loader.load() - - if documents: - text_splitter = RecursiveCharacterTextSplitter( - chunk_size=1000, - chunk_overlap=200, - add_start_index=True - ) - chunked_docs = text_splitter.split_documents(documents) - - self.documents.extend(chunked_docs) - self._save_documents() - return len(chunked_docs) - return 0 - except Exception as e: - raise Exception(f"خطا در بارگیری PDF: {e}") - - def search_duckduckgo(self, query, num_results=None): - if num_results is None: - num_results = self.num_search_results - - try: - search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}" - response = requests.get(search_url, headers=self.headers, timeout=10) - - if response.status_code != 200: - print(f"خطا در جستجوی وب: HTTP {response.status_code}") - return [] - - soup = BeautifulSoup(response.text, 'html.parser') - results = [] - - for element in soup.select('.result__url, .result__a'): - href = element.get('href') if 'href' in element.attrs else None - - if href and not href.startswith('/') and (href.startswith('http://') or href.startswith('https://')): - results.append(href) - elif not href and element.find('a') and 'href' in element.find('a').attrs: - href = element.find('a')['href'] - if href and not href.startswith('/'): - results.append(href) - - unique_results = list(set(results)) - return unique_results[:num_results] - - except Exception as e: - print(f"خطا در جستجوی DuckDuckGo: {e}") - return [] - - def crawl_page(self, url, depth=0): - if depth > self.max_depth: - return None, [] - - try: - response = requests.get(url, headers=self.headers, timeout=10) - response.raise_for_status() - - soup = BeautifulSoup(response.text, 'html.parser') - - title = soup.title.string if soup.title else "بدون عنوان" - - paragraphs = [] - for p in soup.find_all('p'): - text = 
p.get_text(strip=True) - if len(text) > 50: - paragraphs.append(text) - if len(paragraphs) >= self.max_paragraphs: - break - - links = [] - for a in soup.find_all('a', href=True): - href = a['href'] - if href.startswith('http') and href != url: - links.append(href) - if len(links) >= self.max_links_per_page: - break - - content = { - "url": url, - "title": title, - "paragraphs": paragraphs - } - - return content, links - - except Exception as e: - print(f"خطا در خزش صفحه {url}: {e}") - return None, [] - - def crawl_website(self, start_url, max_pages=10): - visited = set() - to_visit = [start_url] - contents = [] - - while to_visit and len(visited) < max_pages: - current_url = to_visit.pop(0) - - if current_url in visited: - continue - - content, links = self.crawl_page(current_url) - - visited.add(current_url) - - if content and content["paragraphs"]: - contents.append(content) - - for link in links: - if link not in visited and link not in to_visit: - to_visit.append(link) - - time.sleep(1) - - return contents - - def crawl_web(self, query): - urls = self.search_duckduckgo(query) - - if not urls: - print("هیچ نتیجه‌ای یافت نشد.") - return [] - - all_results = [] - for url in urls[:3]: # Limit to first 3 URLs for efficiency - content, links = self.crawl_page(url) - if content and content["paragraphs"]: - all_results.append(content) - - # Follow links from the main page (recursive crawling) - for link in links[:2]: # Limit to first 2 links - sub_content, _ = self.crawl_page(link, depth=1) - if sub_content and sub_content["paragraphs"]: - all_results.append(sub_content) - time.sleep(1) - - time.sleep(1) - - self.web_results = all_results - self._save_web_results() - - # Convert web results to documents for RAG - web_docs = [] - for result in all_results: - text = f"[{result['title']}]\n" + "\n".join(result['paragraphs']) - web_docs.append({"page_content": text, "metadata": {"source": result['url']}}) - - return all_results, web_docs - - def build_retriever(self, documents): - if not documents: - return None - - # Create BM25 retriever - bm25_retriever = BM25Retriever.from_documents(documents) - bm25_retriever.k = 3 # Return top 3 results - - return bm25_retriever - - def get_relevant_documents(self, query, documents): - retriever = self.build_retriever(documents) - if not retriever: - return [] - - return retriever.get_relevant_documents(query) - - def extract_context_from_documents(self, query): + def initialize_retriever(self): + """Initialize the hybrid retriever with BM25 and direct Chroma queries""" if not self.documents: + print("No documents loaded, retriever initialization failed") return None - relevant_docs = self.get_relevant_documents(query, self.documents) - - if not relevant_docs: + try: + # Create BM25 retriever + bm25_retriever = BM25Retriever.from_documents(self.documents) + bm25_retriever.k = 4 # Top k results to retrieve + + # Initialize vector store with KNN search + import shutil + if os.path.exists(CHROMA_PERSIST_DIR): + print(f"Removing existing Chroma DB to prevent dimension mismatch") + shutil.rmtree(CHROMA_PERSIST_DIR) + + # Create vector store directly from Chroma + print("Creating vector store...") + vector_store = Chroma.from_documents( + documents=self.documents, + embedding=self.embeddings, + persist_directory=CHROMA_PERSIST_DIR + ) + + vector_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4}) + print(f"Vector retriever created: {type(vector_retriever)}") + + # Create hybrid retriever - BM25 (70%) and Vector (30%) + 
print("Creating hybrid retriever") + hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever, vector_weight=0.3) + print("Hybrid retriever initialized successfully") + return hybrid_retriever + + except Exception as e: + print(f"Error initializing retriever: {e}") + traceback.print_exc() return None - - context = "\n\n".join([doc.page_content for doc in relevant_docs]) - return context - def extract_context_from_web(self, web_results, web_docs, query): - if not web_results or not web_docs: - return None, [] + def estimate_confidence(self, text, query, context=None): + """Estimate confidence of response""" + # Start with baseline confidence + confidence = 0.5 - # Try to use the retriever for better results - if web_docs: - relevant_docs = self.get_relevant_documents(query, web_docs) - if relevant_docs: - context = "\n\n".join([doc.page_content for doc in relevant_docs]) - sources = [doc.metadata.get("source", "") for doc in relevant_docs if "source" in doc.metadata] - return context, sources + # Check for uncertainty markers + uncertainty_phrases = [ + "نمی‌دانم", "مطمئن نیستم", "ممکن است", "شاید", "احتمالاً", + "فکر می‌کنم", "به نظر می‌رسد" + ] - # Fall back to simple extraction if retriever fails - contexts = [] - sources = [] + if any(phrase in text.lower() for phrase in uncertainty_phrases): + confidence -= 0.2 - for doc in web_results: - context_text = "\n".join(doc["paragraphs"]) - contexts.append(f"[{doc['title']}] {context_text}") - sources.append(doc['url']) + # Check for question relevance + query_words = set(re.findall(r'\b\w+\b', query.lower())) + text_words = set(re.findall(r'\b\w+\b', text.lower())) - context = "\n\n".join(contexts) - return context, sources + # Calculate overlap between query and response + if query_words: + overlap_ratio = len(query_words.intersection(text_words)) / len(query_words) + if overlap_ratio > 0.5: + confidence += 0.2 + elif overlap_ratio < 0.2: + confidence -= 0.2 + + # If context provided, check context relevance + if context: + context_words = set(re.findall(r'\b\w+\b', context.lower())) + if context_words: + context_overlap = len(context_words.intersection(text_words)) / len(context_words) + if context_overlap > 0.3: + confidence += 0.2 + else: + confidence -= 0.1 + + # Ensure confidence is within [0,1] + return max(0.0, min(1.0, confidence)) + + def check_direct_knowledge(self, query): + """Check if the LLM can answer directly from its knowledge""" + print("Checking LLM's direct knowledge...") + prompt = f"""به این سوال با استفاده از دانش خود پاسخ دهید. فقط به زبان فارسی پاسخ دهید. 
-def get_context(query, crawl_params=None): - """ - سیستم RAG مدولار برای پاسخگویی به سوالات با استفاده از اسناد و جستجوی وب +سوال: {query} + +پاسخ فارسی:""" + + response = query_llm(prompt, model=LLM_MODEL) + confidence = self.estimate_confidence(response, query) + print(f"LLM direct knowledge confidence: {confidence:.2f}") + + return response, confidence - پارامترها: - query (str): سوال به زبان فارسی - crawl_params (dict, optional): پارامترهای خزش وب - - max_depth: حداکثر عمق خزش - - max_links_per_page: حداکثر تعداد لینک‌های استخراج شده از هر صفحه - - max_paragraphs: حداکثر تعداد پاراگراف‌های استخراج شده از هر صفحه - - num_search_results: تعداد نتایج جستجو - - خروجی: - dict: نتیجه جستجو شامل متن و منابع - """ - rag = ModularRAG() - - # Configure crawling parameters if provided - if crawl_params: - if 'max_depth' in crawl_params: - rag.max_depth = crawl_params['max_depth'] - if 'max_links_per_page' in crawl_params: - rag.max_links_per_page = crawl_params['max_links_per_page'] - if 'max_paragraphs' in crawl_params: - rag.max_paragraphs = crawl_params['max_paragraphs'] - if 'num_search_results' in crawl_params: - rag.num_search_results = crawl_params['num_search_results'] - - # First try to get context from documents - doc_context = rag.extract_context_from_documents(query) - - if doc_context: + def rag_query(self, query): + """Use RAG to retrieve and generate answer""" + if not self.retriever: + print("Retriever not initialized, skipping RAG") + return None, 0.0 + + print("Retrieving documents for RAG...") + # Retrieve relevant documents + docs = self.retriever.get_relevant_documents(query) + + if not docs: + print("No relevant documents found") + return None, 0.0 + + print(f"Retrieved {len(docs)} relevant documents") + + # Prepare context + context = "\n\n".join([doc.page_content for doc in docs]) + sources = [doc.metadata.get("source", "Unknown") for doc in docs] + + # Query LLM with context + prompt = f"""با توجه به اطلاعات زیر، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید. 
+ +اطلاعات: +{context} + +سوال: {query} + +پاسخ فارسی:""" + + response = query_llm(prompt, model=LLM_MODEL) + confidence = self.estimate_confidence(response, query, context) + print(f"RAG confidence: {confidence:.2f}") + return { - "has_context": True, - "context": doc_context, - "source": "documents", - "language": "fa" - } + "response": response, + "confidence": confidence, + "sources": list(set(sources)) + }, confidence - # Fall back to web search - web_results, web_docs = rag.crawl_web(query) - - if web_results: - web_context, sources = rag.extract_context_from_web(web_results, web_docs, query) + def web_search(self, query): + """Search the web for an answer""" + print("Searching web for answer...") + # Search DuckDuckGo + search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}" + response = requests.get(search_url, verify=False, timeout=10) + + if response.status_code != 200: + print(f"Error searching web: HTTP {response.status_code}") + return None, 0.0 + + # Parse results + soup = BeautifulSoup(response.text, 'html.parser') + results = [] + + for element in soup.select('.result__url, .result__a')[:4]: + href = element.get('href') if 'href' in element.attrs else None + + if href and not href.startswith('/') and (href.startswith('http://') or href.startswith('https://')): + results.append(href) + elif not href and element.find('a') and 'href' in element.find('a').attrs: + href = element.find('a')['href'] + if href and not href.startswith('/'): + results.append(href) + + if not results: + print("No web results found") + return None, 0.0 + + # Crawl top results + web_content = [] + for url in results[:3]: + try: + headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"} + page = requests.get(url, headers=headers, timeout=10, verify=False) + page.raise_for_status() + + soup = BeautifulSoup(page.text, 'html.parser') + + # Remove non-content elements + for tag in ['script', 'style', 'nav', 'footer', 'header']: + for element in soup.find_all(tag): + element.decompose() + + # Get paragraphs + paragraphs = [p.get_text(strip=True) for p in soup.find_all('p') + if len(p.get_text(strip=True)) > 20] + + if paragraphs: + web_content.append(f"[Source: {url}] " + " ".join(paragraphs[:5])) + except Exception as e: + print(f"Error crawling {url}: {e}") + + if not web_content: + print("No useful content found from web results") + return None, 0.0 + + # Query LLM with web content + context = "\n\n".join(web_content) + prompt = f"""با توجه به اطلاعات زیر که از وب بدست آمده، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید. 
+ +اطلاعات: +{context} + +سوال: {query} + +پاسخ فارسی:""" + + response = query_llm(prompt, model=LLM_MODEL) + confidence = self.estimate_confidence(response, query, context) + print(f"Web search confidence: {confidence:.2f}") + return { - "has_context": True, - "context": web_context, - "source": "web", - "sources": sources, - "language": "fa" - } + "response": response, + "confidence": confidence, + "sources": results[:3] + }, confidence - # No context found - return { - "has_context": False, - "context": "متأسفانه اطلاعاتی در مورد سوال شما یافت نشد.", - "source": "none", - "language": "fa" - } + def get_answer(self, query): + """Main method to get an answer following the specified architecture""" + print(f"Processing query: {query}") + + # STEP 1: Try direct LLM knowledge + direct_response, direct_confidence = self.check_direct_knowledge(query) + + if direct_confidence >= THRESHOLDS['direct_answer']: + print("Using direct LLM knowledge (high confidence)") + return f"{direct_response}\n\n[Source: LLM Knowledge, Confidence: {direct_confidence:.2f}]" + + # STEP 2: Try RAG with local documents + rag_result, rag_confidence = self.rag_query(query) + + if rag_result and rag_confidence >= THRESHOLDS['rag_confidence']: + print("Using RAG response (sufficient confidence)") + sources_text = ", ".join(rag_result["sources"][:3]) + return f"{rag_result['response']}\n\n[Source: Local Documents, Confidence: {rag_confidence:.2f}, Sources: {sources_text}]" + + # STEP 3: Try web search + web_result, web_confidence = self.web_search(query) + + if web_result and web_confidence >= THRESHOLDS['web_search']: + print("Using web search response (sufficient confidence)") + sources_text = ", ".join(web_result["sources"]) + return f"{web_result['response']}\n\n[Source: Web Search, Confidence: {web_confidence:.2f}, Sources: {sources_text}]" + + # STEP 4: Fall back to direct response with warning + print("No high-confidence source found, using direct response with warning") + return f"{direct_response}\n\n[Warning: Low confidence ({direct_confidence:.2f}). 
Please verify information.]" + +# Simple API functions +def get_answer(query): + """Get an answer for a query""" + system = AgenticQASystem() + return system.get_answer(query) + +# Main entry point +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="QA System") + + mode_group = parser.add_mutually_exclusive_group(required=True) + mode_group.add_argument("--query", "-q", help="Query to answer") + mode_group.add_argument("--interactive", "-i", action="store_true", help="Run in interactive chat mode") + mode_group.add_argument("--test", "-t", action="store_true", help="Run tests") + + args = parser.parse_args() + + if args.interactive: + # Simple interactive mode without memory + qa_system = AgenticQASystem() + print("=== QA System ===") + print("Type 'exit' or 'quit' to end") + + while True: + user_input = input("\nYou: ") + if not user_input.strip(): + continue + + if user_input.lower() in ['exit', 'quit', 'خروج']: + break + + response = qa_system.get_answer(user_input) + print(f"\nBot: {response}") + elif args.query: + qa_system = AgenticQASystem() + print(qa_system.get_answer(args.query)) + elif args.test: + print("Running tests...") diff --git a/hybrid.py b/hybrid.py new file mode 100644 index 0000000..60a1dfa --- /dev/null +++ b/hybrid.py @@ -0,0 +1,204 @@ +import os +import nltk +from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader +from langchain_text_splitters import RecursiveCharacterTextSplitter +from langchain_core.vectorstores import InMemoryVectorStore +from langchain_core.prompts import ChatPromptTemplate +from langchain_ollama import OllamaEmbeddings, ChatOllama +from langchain_community.retrievers import BM25Retriever +from langchain.retrievers import EnsembleRetriever +from typing_extensions import TypedDict +from langgraph.graph import START, END, StateGraph + +# Ensure NLTK tokenizer is available +try: + nltk.data.find('tokenizers/punkt') +except LookupError: + nltk.download('punkt') + +# Initialize model and embeddings +model = ChatOllama(model="gemma3:12b", temperature=0.2) +embeddings = OllamaEmbeddings(model="gemma3:12b") + +# Vector store +vector_store = InMemoryVectorStore(embeddings) + +# Templates +qa_template = """ +You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. +If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. 
+Question: {question} +Context: {context} +Answer: +""" + +# Text splitter +def split_text(documents): + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=1000, + chunk_overlap=200, + add_start_index=True + ) + return text_splitter.split_documents(documents) + +# PDF handling +def load_pdf(file_path): + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + loader = PDFPlumberLoader(file_path) + documents = loader.load() + return documents + +# Web page handling (using WebBaseLoader) +def load_webpage(url): + loader = WebBaseLoader(url) + documents = loader.load() + return documents + +# Hybrid retriever +def build_hybrid_retriever(documents): + vector_store.clear() + vector_store.add_documents(documents) + semantic_retriever = vector_store.as_retriever(search_kwargs={"k": 3}) + bm25_retriever = BM25Retriever.from_documents(documents) + bm25_retriever.k = 3 + hybrid_retriever = EnsembleRetriever( + retrievers=[semantic_retriever, bm25_retriever], + weights=[0.7, 0.3] + ) + return hybrid_retriever + +# DuckDuckGo search implementation +def search_ddg(query, num_results=3): + from langchain_community.utilities import DuckDuckGoSearchAPIWrapper + search = DuckDuckGoSearchAPIWrapper() + results = search.results(query, num_results) + return results + +# Answer question with error handling +def answer_question(question, documents): + try: + context = "\n\n".join([doc.page_content for doc in documents]) + prompt = ChatPromptTemplate.from_template(qa_template) + chain = prompt | model + return chain.invoke({"question": question, "context": context}).content + except Exception as e: + return f"Error generating answer: {e}" + +# Simple RAG node for web search +class WebSearchState(TypedDict): + query: str + results: list + response: str + +def web_search(state): + results = search_ddg(state["query"]) + return {"results": results} + +def generate_search_response(state): + try: + context = "\n\n".join([f"{r['title']}: {r['snippet']}" for r in state["results"]]) + prompt = ChatPromptTemplate.from_template(qa_template) + chain = prompt | model + response = chain.invoke({"question": state["query"], "context": context}) + return {"response": response.content} + except Exception as e: + return {"response": f"Error generating response: {e}"} + +# Build search graph +search_graph = StateGraph(WebSearchState) +search_graph.add_node("search", web_search) +search_graph.add_node("generate", generate_search_response) +search_graph.add_edge(START, "search") +search_graph.add_edge("search", "generate") +search_graph.add_edge("generate", END) +search_workflow = search_graph.compile() + +# Main command-line interface +if __name__ == "__main__": + print("Welcome to the Advanced RAG System") + print("Choose an option:") + print("1. Analyze PDF") + print("2. Crawl URL") + print("3. Search Internet") + choice = input("Enter your choice (1/2/3): ") + + if choice == "1": + pdf_path = input("Enter the path to the PDF file: ").strip() + if not pdf_path: + print("Please enter a valid file path.") + else: + try: + print("Processing PDF...") + documents = load_pdf(pdf_path) + if not documents: + print("No documents were loaded from the PDF. The file might be empty or not contain extractable text.") + else: + chunked_documents = split_text(documents) + if not chunked_documents: + print("No text chunks were created. 
The PDF might not contain any text.") + else: + retriever = build_hybrid_retriever(chunked_documents) + print(f"Processed {len(chunked_documents)} chunks") + question = input("Ask a question about the PDF: ").strip() + if not question: + print("Please enter a valid question.") + else: + print("Searching document...") + related_documents = retriever.get_relevant_documents(question) + if not related_documents: + print("No relevant documents found for the question.") + else: + answer = answer_question(question, related_documents) + print("Answer:", answer) + except Exception as e: + print(f"Error: {e}") + + elif choice == "2": + url = input("Enter the URL to analyze: ").strip() + if not url: + print("Please enter a valid URL.") + else: + try: + print("Loading webpage...") + web_documents = load_webpage(url) + if not web_documents: + print("No documents were loaded from the webpage. The page might be empty or not contain extractable text.") + else: + web_chunks = split_text(web_documents) + if not web_chunks: + print("No text chunks were created. The webpage might not contain any text.") + else: + web_retriever = build_hybrid_retriever(web_chunks) + print(f"Processed {len(web_chunks)} chunks from webpage") + question = input("Ask a question about the webpage: ").strip() + if not question: + print("Please enter a valid question.") + else: + print("Analyzing content...") + web_results = web_retriever.get_relevant_documents(question) + if not web_results: + print("No relevant documents found for the question.") + else: + answer = answer_question(question, web_results) + print("Answer:", answer) + except Exception as e: + print(f"Error loading webpage: {e}") + + elif choice == "3": + query = input("Enter your search query: ").strip() + if not query: + print("Please enter a valid search query.") + else: + try: + print("Searching the web...") + search_result = search_workflow.invoke({"query": query}) + print("Response:", search_result["response"]) + print("Sources:") + for result in search_result["results"]: + print(f"- {result['title']}: {result['link']}") + except Exception as e: + print(f"Error during search: {e}") + + else: + print("Invalid choice") \ No newline at end of file diff --git a/memory.py b/memory.py new file mode 100644 index 0000000..8857618 --- /dev/null +++ b/memory.py @@ -0,0 +1,198 @@ +# --- Dependencies --- +# pip install langchain langchain-core langchain-ollama faiss-cpu sentence-transformers + +import datetime +import os +from langchain_ollama import ChatOllama, OllamaEmbeddings +from langchain.memory import ConversationBufferMemory # Added for intra-session memory +from langchain_community.vectorstores import FAISS +from langchain.prompts import PromptTemplate +from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel +from langchain_core.output_parsers import StrOutputParser +from langchain.schema import Document # Needed for manual saving + +# --- Config --- +FAISS_INDEX_PATH = "my_chatbot_memory_index" # Directory to save/load FAISS index + +# --- Ollama LLM & Embeddings Setup --- +# Run in terminal: ollama pull gemma3 +# Run in terminal: ollama pull nomic-embed-text +OLLAMA_LLM_MODEL = 'gemma3' # Using Gemma 3 as requested +OLLAMA_EMBED_MODEL = 'nomic-embed-text' # Recommended embedding model for Ollama + +try: + llm = ChatOllama(model=OLLAMA_LLM_MODEL) + embeddings = OllamaEmbeddings(model=OLLAMA_EMBED_MODEL) + print(f"Successfully initialized Ollama: LLM='{OLLAMA_LLM_MODEL}', Embeddings='{OLLAMA_EMBED_MODEL}'") + # Optional tests 
removed for brevity
+except Exception as e:
+    print(f"Error initializing Ollama components: {e}")
+    print(f"Ensure Ollama is running & models pulled (e.g., 'ollama pull {OLLAMA_LLM_MODEL}' and 'ollama pull {OLLAMA_EMBED_MODEL}').")
+    exit()
+
+# --- Vector Store (Episodic Memory) Setup --- Persisted!
+try:
+    if os.path.exists(FAISS_INDEX_PATH):
+        print(f"Loading existing FAISS index from: {FAISS_INDEX_PATH}")
+        vectorstore = FAISS.load_local(
+            FAISS_INDEX_PATH,
+            embeddings,
+            allow_dangerous_deserialization=True  # Required for FAISS loading
+        )
+        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
+        print("FAISS vector store loaded successfully.")
+    else:
+        print(f"No FAISS index found at {FAISS_INDEX_PATH}. Initializing new store.")
+        # FAISS needs at least one text to initialize.
+        vectorstore = FAISS.from_texts(
+            ["Initial conversation context placeholder - Bot created"],
+            embeddings
+        )
+        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
+        # Save the initial empty index
+        vectorstore.save_local(FAISS_INDEX_PATH)
+        print("New FAISS vector store initialized and saved.")
+
+except Exception as e:
+    print(f"Error initializing/loading FAISS: {e}")
+    print("Check permissions or delete the index directory if corrupted.")
+    exit()
+
+# --- Conversation Buffer (Short-Term) Memory Setup ---
+# memory_key must match the input variable in the prompt
+# return_messages=True formats history as a list of BaseMessages suitable for the prompt
+buffer_memory = ConversationBufferMemory(
+    memory_key="chat_history",
+    return_messages=True
+)
+# <<< ADDED: Clear buffer at the start of each script run >>>
+buffer_memory.clear()
+
+# --- Define the Prompt Template ---
+# Now includes chat_history for the buffer memory
+template = """You are a helpful chatbot assistant with episodic memory (from past sessions) and conversational awareness (from the current session).
+Use the following relevant pieces of information:
+1. Episodic Memory (Knowledge from *previous* chat sessions):
+{semantic_context}
+
+2. Chat History (What we've discussed in the *current* session):
+{chat_history}
+
+Combine this information with the current user input to generate a coherent and contextually relevant answer.
+If recalling information from Episodic Memory, you can mention it stems from a past conversation if appropriate.
+If no relevant context or history is found, just respond naturally to the current input.
+
+Current Input:
+User: {input}
+Assistant:"""
+
+prompt = PromptTemplate(
+    input_variables=["semantic_context", "chat_history", "input"],
+    template=template
+)
+
+# --- Helper Function for Formatting Retrieved Docs (Episodic Memory) ---
+# Formats the retrieved documents (past interactions) for the prompt
+def format_retrieved_docs(docs):
+    # Simplified formatting: Extract core content only and label explicitly
+    formatted = []
+    for doc in docs:
+        content = doc.page_content
+        # Basic check to remove placeholder
+        if content not in ["Initial conversation context placeholder - Bot created"]:
+            # Attempt to strip "Role (timestamp): " prefix if present
+            if "):" in content:
+                content = content.split("):", 1)[-1].strip()
+            if content:  # Ensure content is not empty after stripping
+                formatted.append(f"Recalled from a past session: {content}")
+    # Use a double newline to separate recalled memories clearly
+    return "\n\n".join(formatted) if formatted else "No relevant memories found from past sessions."
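+
+# Illustrative check of the prefix stripping above (assumes the store holds one
+# past exchange):
+#   format_retrieved_docs([Document(page_content="User (2025-05-02 11:12:23): hello")])
+#   -> "Recalled from a past session: hello"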
+ + +# --- Chain Definition using LCEL --- + +# Function to load episodic memory (FAISS context) +def load_episodic_memory(input_dict): + query = input_dict.get("input", "") + docs = retriever.invoke(query) + return format_retrieved_docs(docs) + +# Function to save episodic memory (and persist FAISS index) +def save_episodic_memory_step(inputs_outputs): + user_input = inputs_outputs.get("input", "") + llm_output = inputs_outputs.get("output", "") + + if user_input and llm_output: + timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + docs_to_add = [ + Document(page_content=f"User ({timestamp}): {user_input}"), + Document(page_content=f"Assistant ({timestamp}): {llm_output}") + ] + vectorstore.add_documents(docs_to_add) + vectorstore.save_local(FAISS_INDEX_PATH) # Persist index after adding + # print(f"DEBUG: Saved to FAISS index: {FAISS_INDEX_PATH}") + return inputs_outputs # Pass the dict through for potential further steps + + +# Define the core chain logic +chain_core = ( + RunnablePassthrough.assign( + semantic_context=RunnableLambda(load_episodic_memory), + chat_history=RunnableLambda(lambda x: buffer_memory.load_memory_variables(x)['chat_history']) + ) + | prompt + | llm + | StrOutputParser() +) + +# Wrap the core logic to handle memory updates +def run_chain(input_dict): + user_input = input_dict['input'] + + # Invoke the core chain to get the response + llm_response = chain_core.invoke({"input": user_input}) + + # Prepare data for saving + save_data = {"input": user_input, "output": llm_response} + + # Save to episodic memory (FAISS) + save_episodic_memory_step(save_data) + + # Save to buffer memory + buffer_memory.save_context({"input": user_input}, {"output": llm_response}) + + return llm_response + + +# --- Chat Loop --- +print(f"\nChatbot Ready! Using Ollama ('{OLLAMA_LLM_MODEL}' chat, '{OLLAMA_EMBED_MODEL}' embed)") +print(f"Episodic memory stored in: {FAISS_INDEX_PATH}") +print("Type 'quit', 'exit', or 'bye' to end the conversation.") + +while True: + user_text = input("You: ") + if user_text.lower() in ["quit", "exit", "bye"]: + # Optionally clear buffer memory on exit if desired + buffer_memory.clear() + print("Chatbot: Goodbye!") + break + if not user_text: + continue + + try: + # Use the wrapper function to handle the chain invocation and memory updates + response = run_chain({"input": user_text}) + print(f"Chatbot: {response}") + + # Optional debug: View buffer memory + # print("DEBUG: Buffer Memory:", buffer_memory.load_memory_variables({})) + # Optional debug: Check vector store size + # print(f"DEBUG: Vector store size: {vectorstore.index.ntotal}") + + except Exception as e: + print(f"\nAn error occurred during the chat chain: {e}") + # Add more detailed error logging if needed + import traceback + print(traceback.format_exc()) + +# --- End of Script --- \ No newline at end of file diff --git a/req.txt b/req.txt new file mode 100644 index 0000000..ccf5f47 --- /dev/null +++ b/req.txt @@ -0,0 +1 @@ +pip install streamlit langchain-core langchain-ollama unstructured[pdf] langchain-text-splitters pypdf
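
Usage sketch (assumes Ollama is running locally with the gemma3 and nomic-embed-text models pulled; note req.txt stores a pip command rather than a standard requirements list, so run it directly):

    pip install streamlit langchain-core langchain-ollama unstructured[pdf] langchain-text-splitters pypdf
    streamlit run Multimodal.py                    # multimodal RAG UI
    python enhanced_combined.py --query "..."      # one-shot QA
    python enhanced_combined.py --interactive      # chat mode
    python hybrid.py                               # menu-driven PDF/URL/search
    python memory.py                               # chatbot with FAISS memory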