From fd7ba8b68c8ee351cf8bedce982db57ec798b61c Mon Sep 17 00:00:00 2001
From: MasihMoafi
Date: Fri, 2 May 2025 06:44:53 +0000
Subject: [PATCH] Add enhanced_combined.py

---
 enhanced_combined.py | 460 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 460 insertions(+)
 create mode 100644 enhanced_combined.py

diff --git a/enhanced_combined.py b/enhanced_combined.py
new file mode 100644
index 0000000..77731c7
--- /dev/null
+++ b/enhanced_combined.py
@@ -0,0 +1,460 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import re
+import shutil
+import ssl
+import argparse
+import requests
+from bs4 import BeautifulSoup
+from urllib.parse import quote
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain_core.documents import Document
+import traceback
+
+# Disable SSL certificate verification globally and silence the resulting
+# urllib3 warnings (insecure in general, but needed behind intercepting proxies)
+ssl._create_default_https_context = ssl._create_unverified_context
+requests.packages.urllib3.disable_warnings()
+
+def clear_proxy_settings():
+    """Remove proxy environment variables that might cause connection issues."""
+    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
+        if var in os.environ:
+            print(f"Removing proxy env var: {var}")
+            del os.environ[var]
+
+# Run at module load time
+clear_proxy_settings()
+
+# Configuration
+DOCUMENT_PATHS = [
+    r'doc1.txt',
+    r'doc2.txt',
+    r'doc3.txt',
+    r'doc4.txt',
+    r'doc5.txt',
+    r'doc6.txt'
+]
+EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+LLM_MODEL = 'gemma3'
+CHUNK_SIZE = 1000   # characters per chunk
+OVERLAP = 200       # characters shared between consecutive chunks
+CHROMA_PERSIST_DIR = 'chroma_db'
+
+# Confidence thresholds for the answer cascade (direct LLM -> RAG -> web)
+THRESHOLDS = {
+    'direct_answer': 0.7,
+    'rag_confidence': 0.6,
+    'web_search': 0.5
+}
+
+def query_llm(prompt, model=LLM_MODEL):
+    """Query the LLM through the local Ollama /api/generate endpoint."""
+    try:
+        ollama_endpoint = "http://localhost:11434/api/generate"
+        payload = {
+            "model": model,
+            "prompt": prompt,
+            "stream": False
+        }
+        # A timeout keeps a stalled Ollama server from hanging the whole pipeline
+        response = requests.post(ollama_endpoint, json=payload, timeout=120)
+
+        if response.status_code == 200:
+            result = response.json()
+            return result.get('response', '')
+        else:
+            print(f"Ollama API error: {response.status_code}")
+            return f"Error calling Ollama API: {response.status_code}"
+    except Exception as e:
+        print(f"Error querying LLM: {e}")
+        return f"Error: {str(e)}"
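+
+# For reference, a successful non-streaming /api/generate call returns a single
+# JSON object. The exact field set varies by Ollama version, but 'response'
+# (the only field read above) and 'done' are stable; illustrative shape:
+#
+#   {
+#     "model": "gemma3",
+#     "created_at": "2025-05-02T06:44:53Z",
+#     "response": "... generated text ...",
+#     "done": true
+#   }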
+
+class BM25Retriever:
+    """Lexical retriever that approximates BM25 with simple term overlap."""
+
+    def __init__(self, documents=None, k=4):
+        self.documents = documents or []
+        self.k = k  # number of results to return
+
+    @classmethod
+    def from_documents(cls, documents):
+        """Create a retriever from a list of Documents."""
+        return cls(documents=documents, k=4)
+
+    def get_relevant_documents(self, query):
+        """Rank documents by term overlap with the query. This is a rough BM25
+        stand-in: no term-frequency saturation, IDF, or length normalization."""
+        scores = []
+        query_terms = set(re.findall(r'\b\w+\b', query.lower()))
+
+        for doc in self.documents:
+            doc_terms = set(re.findall(r'\b\w+\b', doc.page_content.lower()))
+            # Count shared terms as a cheap approximation of BM25 scoring
+            overlap = len(query_terms.intersection(doc_terms))
+            scores.append((doc, overlap))
+
+        # Sort by score and return the top k
+        sorted_docs = [doc for doc, score in sorted(scores, key=lambda x: x[1], reverse=True)]
+        return sorted_docs[:self.k]
+
+class HybridRetriever:
+    """Hybrid retriever combining BM25 and vector search with configurable weights."""
+
+    def __init__(self, vector_retriever, bm25_retriever, vector_weight=0.3):
+        """Initialize with separate retrievers; the two weights sum to 1."""
+        self._vector_retriever = vector_retriever
+        self._bm25_retriever = bm25_retriever
+        self._vector_weight = vector_weight
+        self._bm25_weight = 1.0 - vector_weight
+
+    def get_relevant_documents(self, query):
+        """Get relevant documents using a weighted combination of both retrievers."""
+        try:
+            # Get results from both retrievers
+            vector_docs = self._vector_retriever.get_relevant_documents(query)
+            bm25_docs = self._bm25_retriever.get_relevant_documents(query)
+
+            # Track unique documents and their accumulated scores
+            doc_dict = {}
+
+            # Add vector results, scored by inverse rank times the vector weight
+            for i, doc in enumerate(vector_docs):
+                score = (len(vector_docs) - i) * self._vector_weight
+                doc_id = doc.page_content[:50]  # first 50 chars as a simple ID
+                if doc_id in doc_dict:
+                    doc_dict[doc_id]["score"] += score
+                else:
+                    doc_dict[doc_id] = {"doc": doc, "score": score}
+
+            # Add BM25 results, scored by inverse rank times the BM25 weight
+            for i, doc in enumerate(bm25_docs):
+                score = (len(bm25_docs) - i) * self._bm25_weight
+                doc_id = doc.page_content[:50]
+                if doc_id in doc_dict:
+                    doc_dict[doc_id]["score"] += score
+                else:
+                    doc_dict[doc_id] = {"doc": doc, "score": score}
+
+            # Sort by combined score, highest first, and return the documents
+            sorted_docs = sorted(doc_dict.values(), key=lambda x: x["score"], reverse=True)
+            return [item["doc"] for item in sorted_docs]
+        except Exception as e:
+            print(f"Error in hybrid retrieval: {e}")
+            return []
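+
+# Worked example of the rank-fusion above (illustrative numbers): with
+# vector_weight=0.3 (hence bm25_weight=0.7) and 4 results per retriever, a
+# chunk ranked 1st by the vector store and 2nd by BM25 accumulates
+#
+#   vector: (4 - 0) * 0.3 = 1.2
+#   bm25:   (4 - 1) * 0.7 = 2.1
+#   total:                  3.3
+#
+# whereas a chunk found only by BM25 at rank 1 scores (4 - 0) * 0.7 = 2.8,
+# so lexical matches dominate, which is what the 70/30 split intends.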
+
+class AgenticQASystem:
+    """QA system implementing the direct-LLM / RAG / web-search cascade."""
+
+    def __init__(self):
+        """Initialize the QA system with embeddings, documents, and retrievers."""
+        # Load embeddings
+        self.embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+        # Load documents and retrievers
+        self.documents = self.load_documents()
+        self.retriever = self.initialize_retriever()
+
+    def load_documents(self):
+        """Load documents from configured paths with sliding-window chunking."""
+        print("Loading documents...")
+        docs = []
+        for path in DOCUMENT_PATHS:
+            try:
+                with open(path, 'r', encoding='utf-8') as f:
+                    text = re.sub(r'\s+', ' ', f.read()).strip()
+                # Sliding-window chunking: the stride is CHUNK_SIZE - OVERLAP,
+                # so consecutive chunks share OVERLAP characters of context
+                chunks = [text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE - OVERLAP)]
+                for chunk in chunks:
+                    docs.append(Document(
+                        page_content=chunk,
+                        metadata={"source": os.path.basename(path)}
+                    ))
+            except Exception as e:
+                print(f"Error loading document {path}: {e}")
+        print(f"Loaded {len(docs)} document chunks")
+        return docs
+
+    def initialize_retriever(self):
+        """Initialize the hybrid retriever (BM25 plus Chroma vector search)."""
+        if not self.documents:
+            print("No documents loaded, retriever initialization failed")
+            return None
+
+        try:
+            # Create the BM25 retriever
+            bm25_retriever = BM25Retriever.from_documents(self.documents)
+            bm25_retriever.k = 4  # top k results to retrieve
+
+            # Rebuild the vector store from scratch: embeddings persisted from a
+            # different model would otherwise cause a dimension mismatch in Chroma
+            if os.path.exists(CHROMA_PERSIST_DIR):
+                print("Removing existing Chroma DB to prevent dimension mismatch")
+                shutil.rmtree(CHROMA_PERSIST_DIR)
+
+            # Create the vector store directly from the chunked documents
+            print("Creating vector store...")
+            vector_store = Chroma.from_documents(
+                documents=self.documents,
+                embedding=self.embeddings,
+                persist_directory=CHROMA_PERSIST_DIR
+            )
+
+            vector_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})
+            print(f"Vector retriever created: {type(vector_retriever)}")
+
+            # Create the hybrid retriever - BM25 (70%) and vector (30%)
+            print("Creating hybrid retriever")
+            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever, vector_weight=0.3)
+            print("Hybrid retriever initialized successfully")
+            return hybrid_retriever
+
+        except Exception as e:
+            print(f"Error initializing retriever: {e}")
+            traceback.print_exc()
+            return None
+
+    def estimate_confidence(self, text, query, context=None):
+        """Heuristically estimate the confidence of a response, in [0, 1]."""
+        # Start from a neutral baseline
+        confidence = 0.5
+
+        # Persian hedging phrases ("I don't know", "I'm not sure", "it may be",
+        # "perhaps", "probably", "I think", "it seems")
+        uncertainty_phrases = [
+            "نمی‌دانم", "مطمئن نیستم", "ممکن است", "شاید", "احتمالاً",
+            "فکر می‌کنم", "به نظر می‌رسد"
+        ]
+
+        if any(phrase in text.lower() for phrase in uncertainty_phrases):
+            confidence -= 0.2
+
+        # Check lexical relevance of the response to the question
+        query_words = set(re.findall(r'\b\w+\b', query.lower()))
+        text_words = set(re.findall(r'\b\w+\b', text.lower()))
+
+        # Reward high query/response term overlap, penalize low overlap
+        if query_words:
+            overlap_ratio = len(query_words.intersection(text_words)) / len(query_words)
+            if overlap_ratio > 0.5:
+                confidence += 0.2
+            elif overlap_ratio < 0.2:
+                confidence -= 0.2
+
+        # If retrieval context was provided, check that the answer draws on it
+        if context:
+            context_words = set(re.findall(r'\b\w+\b', context.lower()))
+            if context_words:
+                context_overlap = len(context_words.intersection(text_words)) / len(context_words)
+                if context_overlap > 0.3:
+                    confidence += 0.2
+                else:
+                    confidence -= 0.1
+
+        # Clamp to [0, 1]
+        return max(0.0, min(1.0, confidence))
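+
+    # Worked example of the heuristic above (illustrative): from the 0.5
+    # baseline, an answer with no hedging phrases that shares over half of the
+    # query's terms (+0.2) and over 30% of the context's terms (+0.2) ends at
+    # 0.9, clearing every threshold. Add one hedge such as "ممکن است" ("it may
+    # be") and it drops by 0.2 to 0.7, which still meets
+    # THRESHOLDS['direct_answer'] only because the comparison uses >=.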
+
+    def check_direct_knowledge(self, query):
+        """Check whether the LLM can answer directly from its own knowledge."""
+        print("Checking LLM's direct knowledge...")
+        # Prompt (Persian): "Answer this question using your own knowledge.
+        # Answer only in Persian. / Question: ... / Persian answer:"
+        prompt = f"""به این سوال با استفاده از دانش خود پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+
+سوال: {query}
+
+پاسخ فارسی:"""
+
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query)
+        print(f"LLM direct knowledge confidence: {confidence:.2f}")
+
+        return response, confidence
+
+    def rag_query(self, query):
+        """Use RAG over the local documents to retrieve context and generate an answer."""
+        if not self.retriever:
+            print("Retriever not initialized, skipping RAG")
+            return None, 0.0
+
+        print("Retrieving documents for RAG...")
+        # Retrieve relevant document chunks
+        docs = self.retriever.get_relevant_documents(query)
+
+        if not docs:
+            print("No relevant documents found")
+            return None, 0.0
+
+        print(f"Retrieved {len(docs)} relevant documents")
+
+        # Prepare the context block and collect source filenames
+        context = "\n\n".join([doc.page_content for doc in docs])
+        sources = [doc.metadata.get("source", "Unknown") for doc in docs]
+
+        # Prompt (Persian): "Given the information below, answer the question.
+        # Answer only in Persian. / Information: ... / Question: ... / Persian answer:"
+        prompt = f"""با توجه به اطلاعات زیر، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+
+اطلاعات:
+{context}
+
+سوال: {query}
+
+پاسخ فارسی:"""
+
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query, context)
+        print(f"RAG confidence: {confidence:.2f}")
+
+        return {
+            "response": response,
+            "confidence": confidence,
+            "sources": list(set(sources))  # deduplicate source filenames
+        }, confidence
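+
+    # On success, rag_query returns a (payload, confidence) pair, e.g. with
+    # illustrative values:
+    #
+    #   ({"response": "...", "confidence": 0.7, "sources": ["doc2.txt"]}, 0.7)
+    #
+    # get_answer() only uses the payload when the confidence clears
+    # THRESHOLDS['rag_confidence']; otherwise it falls through to web search.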
+
+    def web_search(self, query):
+        """Search the web (DuckDuckGo HTML endpoint) for an answer."""
+        print("Searching web for answer...")
+        # Query DuckDuckGo's HTML interface; guard against network errors so a
+        # failed search degrades gracefully instead of crashing get_answer()
+        try:
+            search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
+            response = requests.get(search_url, verify=False, timeout=10)
+        except Exception as e:
+            print(f"Error searching web: {e}")
+            return None, 0.0
+
+        if response.status_code != 200:
+            print(f"Error searching web: HTTP {response.status_code}")
+            return None, 0.0
+
+        # Parse result links out of the search page
+        soup = BeautifulSoup(response.text, 'html.parser')
+        results = []
+
+        for element in soup.select('.result__url, .result__a')[:4]:
+            href = element.get('href') if 'href' in element.attrs else None
+
+            # Keep only absolute http(s) links; skip DuckDuckGo-internal paths
+            if href and not href.startswith('/') and (href.startswith('http://') or href.startswith('https://')):
+                results.append(href)
+            elif not href and element.find('a') and 'href' in element.find('a').attrs:
+                href = element.find('a')['href']
+                if href and not href.startswith('/'):
+                    results.append(href)
+
+        if not results:
+            print("No web results found")
+            return None, 0.0
+
+        # Crawl the top results and extract paragraph text
+        web_content = []
+        for url in results[:3]:
+            try:
+                headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
+                page = requests.get(url, headers=headers, timeout=10, verify=False)
+                page.raise_for_status()
+
+                soup = BeautifulSoup(page.text, 'html.parser')
+
+                # Remove non-content elements
+                for tag in ['script', 'style', 'nav', 'footer', 'header']:
+                    for element in soup.find_all(tag):
+                        element.decompose()
+
+                # Keep paragraphs long enough to carry real content
+                paragraphs = [p.get_text(strip=True) for p in soup.find_all('p')
+                              if len(p.get_text(strip=True)) > 20]
+
+                if paragraphs:
+                    web_content.append(f"[Source: {url}] " + " ".join(paragraphs[:5]))
+            except Exception as e:
+                print(f"Error crawling {url}: {e}")
+
+        if not web_content:
+            print("No useful content found from web results")
+            return None, 0.0
+
+        # Query the LLM with the crawled web content as context.
+        # Prompt (Persian): "Given the information below, obtained from the web,
+        # answer the question. Answer only in Persian."
+        context = "\n\n".join(web_content)
+        prompt = f"""با توجه به اطلاعات زیر که از وب بدست آمده، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+
+اطلاعات:
+{context}
+
+سوال: {query}
+
+پاسخ فارسی:"""
+
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query, context)
+        print(f"Web search confidence: {confidence:.2f}")
+
+        return {
+            "response": response,
+            "confidence": confidence,
+            "sources": results[:3]
+        }, confidence
+
+    def get_answer(self, query):
+        """Answer a query via the cascade: direct LLM, then RAG, then web search."""
+        print(f"Processing query: {query}")
+
+        # STEP 1: Try the LLM's direct knowledge
+        direct_response, direct_confidence = self.check_direct_knowledge(query)
+
+        if direct_confidence >= THRESHOLDS['direct_answer']:
+            print("Using direct LLM knowledge (high confidence)")
+            return f"{direct_response}\n\n[Source: LLM Knowledge, Confidence: {direct_confidence:.2f}]"
+
+        # STEP 2: Try RAG over the local documents
+        rag_result, rag_confidence = self.rag_query(query)
+
+        if rag_result and rag_confidence >= THRESHOLDS['rag_confidence']:
+            print("Using RAG response (sufficient confidence)")
+            sources_text = ", ".join(rag_result["sources"][:3])
+            return f"{rag_result['response']}\n\n[Source: Local Documents, Confidence: {rag_confidence:.2f}, Sources: {sources_text}]"
+
+        # STEP 3: Try a web search
+        web_result, web_confidence = self.web_search(query)
+
+        if web_result and web_confidence >= THRESHOLDS['web_search']:
+            print("Using web search response (sufficient confidence)")
+            sources_text = ", ".join(web_result["sources"])
+            return f"{web_result['response']}\n\n[Source: Web Search, Confidence: {web_confidence:.2f}, Sources: {sources_text}]"
+
+        # STEP 4: Fall back to the direct response, flagged as low confidence
+        print("No high-confidence source found, using direct response with warning")
+        return f"{direct_response}\n\n[Warning: Low confidence ({direct_confidence:.2f}). Please verify information.]"
+
+# Simple API function
+def get_answer(query):
+    """Get an answer for a single query. Note: this builds a fresh
+    AgenticQASystem (and re-embeds all documents) on every call, so it is
+    only suitable for one-off use."""
+    system = AgenticQASystem()
+    return system.get_answer(query)
+
+# Main entry point
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="QA System")
+
+    mode_group = parser.add_mutually_exclusive_group(required=True)
+    mode_group.add_argument("--query", "-q", help="Query to answer")
+    mode_group.add_argument("--interactive", "-i", action="store_true", help="Run in interactive chat mode")
+    mode_group.add_argument("--test", "-t", action="store_true", help="Run tests")
+
+    args = parser.parse_args()
+
+    if args.interactive:
+        # Simple interactive mode without conversation memory
+        qa_system = AgenticQASystem()
+        print("=== QA System ===")
+        print("Type 'exit' or 'quit' to end")
+
+        while True:
+            user_input = input("\nYou: ")
+            if not user_input.strip():
+                continue
+
+            # 'خروج' is Persian for 'exit'
+            if user_input.lower() in ['exit', 'quit', 'خروج']:
+                break
+
+            response = qa_system.get_answer(user_input)
+            print(f"\nBot: {response}")
+    elif args.query:
+        qa_system = AgenticQASystem()
+        print(qa_system.get_answer(args.query))
+    elif args.test:
+        # Placeholder: no automated tests are wired up yet
+        print("Running tests...")
\ No newline at end of file
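--
Usage sketch, assuming an Ollama server on localhost:11434 with the gemma3
model pulled and doc1.txt through doc6.txt present beside the script:

    python enhanced_combined.py --query "your question"   # one-shot answer
    python enhanced_combined.py --interactive             # chat loop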