Add enhanced_combined.py
commit fd7ba8b68c

enhanced_combined.py (new file, 460 lines)

@@ -0,0 +1,460 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
import json
import ssl
import shutil
import argparse
import traceback
from urllib.parse import quote

import requests
from bs4 import BeautifulSoup
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document

# Disable SSL certificate verification globally and silence the resulting
# urllib3 warnings (needed for the unverified web requests made below)
ssl._create_default_https_context = ssl._create_unverified_context
requests.packages.urllib3.disable_warnings()

def clear_proxy_settings():
    """Remove proxy environment variables that might cause connection issues."""
    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
        if var in os.environ:
            print(f"Removing proxy env var: {var}")
            del os.environ[var]


# Run at module load time
clear_proxy_settings()

# Configuration
DOCUMENT_PATHS = [
    'doc1.txt',
    'doc2.txt',
    'doc3.txt',
    'doc4.txt',
    'doc5.txt',
    'doc6.txt',
]
EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
LLM_MODEL = 'gemma3'
CHUNK_SIZE = 1000  # characters per chunk
OVERLAP = 200      # characters shared between consecutive chunks
CHROMA_PERSIST_DIR = 'chroma_db'

# Confidence thresholds for the answer cascade (direct LLM -> RAG -> web search)
THRESHOLDS = {
    'direct_answer': 0.7,
    'rag_confidence': 0.6,
    'web_search': 0.5,
}

def query_llm(prompt, model=LLM_MODEL):
    """Query the LLM directly through the local Ollama API."""
    try:
        ollama_endpoint = "http://localhost:11434/api/generate"
        payload = {
            "model": model,
            "prompt": prompt,
            "stream": False
        }
        # Generation can be slow; a timeout avoids hanging forever
        response = requests.post(ollama_endpoint, json=payload, timeout=120)

        if response.status_code == 200:
            result = response.json()
            return result.get('response', '')
        else:
            print(f"Ollama API error: {response.status_code}")
            return f"Error calling Ollama API: {response.status_code}"
    except Exception as e:
        print(f"Error querying LLM: {e}")
        return f"Error: {str(e)}"


class BM25Retriever:
    """Lexical retriever approximating BM25 with simple term-overlap scoring."""

    def __init__(self):
        self.documents = []
        self.k = 4  # number of results to return

    @classmethod
    def from_documents(cls, documents):
        """Create a retriever over the given documents."""
        retriever = cls()
        retriever.documents = documents
        return retriever

    def get_relevant_documents(self, query):
        """Rank documents by term overlap with the query and return the top k."""
        scores = []
        query_terms = set(re.findall(r'\b\w+\b', query.lower()))

        for doc in self.documents:
            doc_terms = set(re.findall(r'\b\w+\b', doc.page_content.lower()))
            # Term overlap is a crude stand-in for BM25: it ignores term
            # frequency and document-length normalization
            overlap = len(query_terms.intersection(doc_terms))
            scores.append((doc, overlap))

        # Sort by score (highest first) and return the top k
        sorted_docs = [doc for doc, score in sorted(scores, key=lambda x: x[1], reverse=True)]
        return sorted_docs[:self.k]
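
# The overlap scorer above ignores term frequency and document length. A
# minimal sketch of true BM25 ranking as an alternative body for
# get_relevant_documents, assuming the optional third-party rank_bm25 package
# (pip install rank-bm25) is installed; kept commented out, not wired in:
#
#   from rank_bm25 import BM25Okapi
#
#   tokenized_corpus = [re.findall(r'\b\w+\b', d.page_content.lower())
#                       for d in self.documents]
#   bm25 = BM25Okapi(tokenized_corpus)
#   scores = bm25.get_scores(re.findall(r'\b\w+\b', query.lower()))
#   ranked = sorted(range(len(scores)), key=lambda i: -scores[i])
#   return [self.documents[i] for i in ranked[:self.k]]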


class HybridRetriever:
    """Hybrid retriever combining BM25 and vector search with configurable weights."""

    def __init__(self, vector_retriever, bm25_retriever, vector_weight=0.3):
        """Initialize with the two retrievers and the weight given to vector results."""
        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._vector_weight = vector_weight
        self._bm25_weight = 1.0 - vector_weight

    def get_relevant_documents(self, query):
        """Get relevant documents using a weighted combination of both retrievers."""
        try:
            # Get results from both retrievers
            vector_docs = self._vector_retriever.get_relevant_documents(query)
            bm25_docs = self._bm25_retriever.get_relevant_documents(query)

            # Track unique documents and their combined scores
            doc_dict = {}

            def add_scores(docs, weight):
                for i, doc in enumerate(docs):
                    # Score based on position (inverse rank)
                    score = (len(docs) - i) * weight
                    doc_id = doc.page_content[:50]  # first 50 chars as a simple ID
                    if doc_id in doc_dict:
                        doc_dict[doc_id]["score"] += score
                    else:
                        doc_dict[doc_id] = {"doc": doc, "score": score}

            add_scores(vector_docs, self._vector_weight)
            add_scores(bm25_docs, self._bm25_weight)

            # Sort by combined score (highest first) and return the documents
            sorted_docs = sorted(doc_dict.values(), key=lambda x: x["score"], reverse=True)
            return [item["doc"] for item in sorted_docs]
        except Exception as e:
            print(f"Error in hybrid retrieval: {e}")
            return []
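
# Worked example of the rank-weighted merge (illustrative numbers): with
# vector_weight=0.3 and 4 results per retriever, a document ranked 1st by the
# vector retriever scores 4 * 0.3 = 1.2; if BM25 also ranks it 2nd, that adds
# 3 * 0.7 = 2.1, for a combined score of 3.3 -- so BM25 rank dominates, as the
# 70/30 weighting intends.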


class AgenticQASystem:
    """QA system implementing the direct-LLM / RAG / web-search cascade."""

    def __init__(self):
        """Initialize the QA system with embeddings, documents, and retrievers."""
        self.embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
        self.documents = self.load_documents()
        self.retriever = self.initialize_retriever()

    def load_documents(self):
        """Load documents from configured paths with sliding-window chunking."""
        print("Loading documents...")
        docs = []
        for path in DOCUMENT_PATHS:
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    # Collapse all whitespace runs to single spaces
                    text = re.sub(r'\s+', ' ', f.read()).strip()
                # Sliding-window chunking
                chunks = [text[i:i + CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE - OVERLAP)]
                for chunk in chunks:
                    docs.append(Document(
                        page_content=chunk,
                        metadata={"source": os.path.basename(path)}
                    ))
            except Exception as e:
                print(f"Error loading document {path}: {e}")
        print(f"Loaded {len(docs)} document chunks")
        return docs
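
    # With CHUNK_SIZE=1000 and OVERLAP=200 the window above starts chunks at
    # offsets 0, 800, 1600, ..., so each chunk repeats the last 200 characters
    # of the previous one.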

    def initialize_retriever(self):
        """Initialize the hybrid retriever over BM25 and Chroma vector search."""
        if not self.documents:
            print("No documents loaded, retriever initialization failed")
            return None

        try:
            # Create BM25 retriever
            bm25_retriever = BM25Retriever.from_documents(self.documents)
            bm25_retriever.k = 4  # top k results to retrieve

            # Start from a clean vector store: a persisted collection built
            # with a different embedding model would cause a dimension mismatch
            if os.path.exists(CHROMA_PERSIST_DIR):
                print("Removing existing Chroma DB to prevent dimension mismatch")
                shutil.rmtree(CHROMA_PERSIST_DIR)

            # Create vector store directly from Chroma
            print("Creating vector store...")
            vector_store = Chroma.from_documents(
                documents=self.documents,
                embedding=self.embeddings,
                persist_directory=CHROMA_PERSIST_DIR
            )

            vector_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})
            print(f"Vector retriever created: {type(vector_retriever)}")

            # Create hybrid retriever - BM25 (70%) and vector (30%)
            print("Creating hybrid retriever")
            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever, vector_weight=0.3)
            print("Hybrid retriever initialized successfully")
            return hybrid_retriever

        except Exception as e:
            print(f"Error initializing retriever: {e}")
            traceback.print_exc()
            return None
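
    # Note: the rebuild above re-embeds every chunk on each start. A minimal
    # sketch of reusing a persisted collection instead, assuming the embedding
    # model has not changed since it was built (not used by this file):
    #
    #   vector_store = Chroma(
    #       persist_directory=CHROMA_PERSIST_DIR,
    #       embedding_function=self.embeddings,
    #   )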

    def estimate_confidence(self, text, query, context=None):
        """Heuristically estimate the confidence of a response in [0, 1]."""
        # Start with baseline confidence
        confidence = 0.5

        # Check for Persian uncertainty markers: "I don't know", "I'm not
        # sure", "it's possible", "maybe", "probably", "I think", "it seems"
        uncertainty_phrases = [
            "نمیدانم", "مطمئن نیستم", "ممکن است", "شاید", "احتمالاً",
            "فکر میکنم", "به نظر میرسد"
        ]

        if any(phrase in text.lower() for phrase in uncertainty_phrases):
            confidence -= 0.2

        # Check for question relevance
        query_words = set(re.findall(r'\b\w+\b', query.lower()))
        text_words = set(re.findall(r'\b\w+\b', text.lower()))

        # Reward high overlap between query and response, penalize low overlap
        if query_words:
            overlap_ratio = len(query_words.intersection(text_words)) / len(query_words)
            if overlap_ratio > 0.5:
                confidence += 0.2
            elif overlap_ratio < 0.2:
                confidence -= 0.2

        # If context is provided, check that the response draws on it
        if context:
            context_words = set(re.findall(r'\b\w+\b', context.lower()))
            if context_words:
                context_overlap = len(context_words.intersection(text_words)) / len(context_words)
                if context_overlap > 0.3:
                    confidence += 0.2
                else:
                    confidence -= 0.1

        # Clamp confidence to [0, 1]
        return max(0.0, min(1.0, confidence))
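
    # Example of the scoring arithmetic: a response with no uncertainty markers
    # that echoes most query terms (overlap ratio 0.6, +0.2) and draws on the
    # retrieved context (context overlap 0.4, +0.2) scores 0.5 + 0.2 + 0.2 = 0.9.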

    def check_direct_knowledge(self, query):
        """Check whether the LLM can answer directly from its own knowledge."""
        print("Checking LLM's direct knowledge...")
        # Persian prompt: "Answer this question using your own knowledge.
        # Respond only in Persian. / Question: {query} / Persian answer:"
        prompt = f"""به این سوال با استفاده از دانش خود پاسخ دهید. فقط به زبان فارسی پاسخ دهید.

سوال: {query}

پاسخ فارسی:"""

        response = query_llm(prompt, model=LLM_MODEL)
        confidence = self.estimate_confidence(response, query)
        print(f"LLM direct knowledge confidence: {confidence:.2f}")

        return response, confidence

    def rag_query(self, query):
        """Use RAG over the local documents to retrieve context and generate an answer."""
        if not self.retriever:
            print("Retriever not initialized, skipping RAG")
            return None, 0.0

        print("Retrieving documents for RAG...")
        # Retrieve relevant documents
        docs = self.retriever.get_relevant_documents(query)

        if not docs:
            print("No relevant documents found")
            return None, 0.0

        print(f"Retrieved {len(docs)} relevant documents")

        # Prepare context
        context = "\n\n".join([doc.page_content for doc in docs])
        sources = [doc.metadata.get("source", "Unknown") for doc in docs]

        # Persian prompt: "Given the information below, answer the question.
        # Respond only in Persian. / Information: {context} / Question: {query}
        # / Persian answer:"
        prompt = f"""با توجه به اطلاعات زیر، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.

اطلاعات:
{context}

سوال: {query}

پاسخ فارسی:"""

        response = query_llm(prompt, model=LLM_MODEL)
        confidence = self.estimate_confidence(response, query, context)
        print(f"RAG confidence: {confidence:.2f}")

        return {
            "response": response,
            "confidence": confidence,
            "sources": list(set(sources))
        }, confidence

    def web_search(self, query):
        """Search the web and generate an answer from the crawled content."""
        print("Searching web for answer...")
        # Search DuckDuckGo's HTML endpoint
        search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
        try:
            response = requests.get(search_url, verify=False, timeout=10)
        except Exception as e:
            print(f"Error searching web: {e}")
            return None, 0.0

        if response.status_code != 200:
            print(f"Error searching web: HTTP {response.status_code}")
            return None, 0.0

        # Parse result links
        soup = BeautifulSoup(response.text, 'html.parser')
        results = []

        for element in soup.select('.result__url, .result__a')[:4]:
            href = element.get('href')

            if href and not href.startswith('/') and href.startswith(('http://', 'https://')):
                results.append(href)
            elif not href and element.find('a') and 'href' in element.find('a').attrs:
                href = element.find('a')['href']
                if href and not href.startswith('/'):
                    results.append(href)

        if not results:
            print("No web results found")
            return None, 0.0

        # Crawl the top results
        web_content = []
        for url in results[:3]:
            try:
                headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
                page = requests.get(url, headers=headers, timeout=10, verify=False)
                page.raise_for_status()

                soup = BeautifulSoup(page.text, 'html.parser')

                # Remove non-content elements
                for tag in ['script', 'style', 'nav', 'footer', 'header']:
                    for element in soup.find_all(tag):
                        element.decompose()

                # Keep paragraphs with enough text to be meaningful
                paragraphs = [p.get_text(strip=True) for p in soup.find_all('p')
                              if len(p.get_text(strip=True)) > 20]

                if paragraphs:
                    web_content.append(f"[Source: {url}] " + " ".join(paragraphs[:5]))
            except Exception as e:
                print(f"Error crawling {url}: {e}")

        if not web_content:
            print("No useful content found from web results")
            return None, 0.0

        # Query LLM with the crawled content.
        # Persian prompt: "Given the information below, obtained from the web,
        # answer the question. Respond only in Persian."
        context = "\n\n".join(web_content)
        prompt = f"""با توجه به اطلاعات زیر که از وب بدست آمده، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.

اطلاعات:
{context}

سوال: {query}

پاسخ فارسی:"""

        response = query_llm(prompt, model=LLM_MODEL)
        confidence = self.estimate_confidence(response, query, context)
        print(f"Web search confidence: {confidence:.2f}")

        return {
            "response": response,
            "confidence": confidence,
            "sources": results[:3]
        }, confidence

    def get_answer(self, query):
        """Answer a query via the cascade: direct LLM, then RAG, then web search."""
        print(f"Processing query: {query}")

        # STEP 1: Try the LLM's direct knowledge
        direct_response, direct_confidence = self.check_direct_knowledge(query)

        if direct_confidence >= THRESHOLDS['direct_answer']:
            print("Using direct LLM knowledge (high confidence)")
            return f"{direct_response}\n\n[Source: LLM Knowledge, Confidence: {direct_confidence:.2f}]"

        # STEP 2: Try RAG with the local documents
        rag_result, rag_confidence = self.rag_query(query)

        if rag_result and rag_confidence >= THRESHOLDS['rag_confidence']:
            print("Using RAG response (sufficient confidence)")
            sources_text = ", ".join(rag_result["sources"][:3])
            return f"{rag_result['response']}\n\n[Source: Local Documents, Confidence: {rag_confidence:.2f}, Sources: {sources_text}]"

        # STEP 3: Try web search
        web_result, web_confidence = self.web_search(query)

        if web_result and web_confidence >= THRESHOLDS['web_search']:
            print("Using web search response (sufficient confidence)")
            sources_text = ", ".join(web_result["sources"])
            return f"{web_result['response']}\n\n[Source: Web Search, Confidence: {web_confidence:.2f}, Sources: {sources_text}]"

        # STEP 4: Fall back to the direct response with a warning
        print("No high-confidence source found, using direct response with warning")
        return f"{direct_response}\n\n[Warning: Low confidence ({direct_confidence:.2f}). Please verify information.]"


# Simple module-level API
def get_answer(query):
    """Answer a single query. Note this builds a fresh AgenticQASystem (and so
    reloads and re-embeds all documents) on every call."""
    system = AgenticQASystem()
    return system.get_answer(query)


# Main entry point
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="QA System")

    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("--query", "-q", help="Query to answer")
    mode_group.add_argument("--interactive", "-i", action="store_true", help="Run in interactive chat mode")
    mode_group.add_argument("--test", "-t", action="store_true", help="Run tests")

    args = parser.parse_args()

    if args.interactive:
        # Simple interactive mode without conversation memory
        qa_system = AgenticQASystem()
        print("=== QA System ===")
        print("Type 'exit' or 'quit' to end")

        while True:
            user_input = input("\nYou: ")
            if not user_input.strip():
                continue

            # 'خروج' is Persian for "exit"
            if user_input.lower() in ['exit', 'quit', 'خروج']:
                break

            response = qa_system.get_answer(user_input)
            print(f"\nBot: {response}")
    elif args.query:
        qa_system = AgenticQASystem()
        print(qa_system.get_answer(args.query))
    elif args.test:
        print("Running tests...")