#!/usr/bin/env python3
# -*- coding: utf-8 -*-
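
"""Organic chemistry web crawler with a hybrid-RAG answering pipeline.

The script searches DuckDuckGo and/or arXiv, crawls the resulting pages,
chunks and indexes the extracted text (Chroma vector store + BM25), and
answers questions with an LLM (see query_llm). Answers are chosen by a
confidence cascade: semantic-memory corrections first, then direct LLM
knowledge, then retrieval-augmented generation over the crawled content.
A CLI in main() exposes query and correction modes.
"""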

import os
import re
import json
import time
import requests
import argparse
import numpy as np
import traceback
import datetime
from urllib.parse import urljoin, urlparse, quote
import logging
from concurrent.futures import ThreadPoolExecutor
from bs4 import BeautifulSoup
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_community.retrievers import BM25Retriever
from langchain_community.embeddings import HuggingFaceEmbeddings

# Disable proxy settings that might cause connection issues
def clear_proxy_settings():
    """Remove proxy environment variables that might cause connection issues."""
    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
        if var in os.environ:
            print(f"Removing proxy env var: {var}")
            del os.environ[var]

# Run at module load time
clear_proxy_settings()

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler("organic_chemistry_crawler.log")
    ]
)
logger = logging.getLogger(__name__)

# Configuration
class Config:
    # Search settings
    SEARCH_ENGINE = "combined"  # Options: "duckduckgo", "arxiv", "combined"
    NUM_SEARCH_RESULTS = 10

    # Crawling settings
    MAX_DEPTH = 1  # How deep to follow links from initial pages
    MAX_LINKS_PER_PAGE = 5  # Max links to follow from each page
    MAX_TOTAL_PAGES = 20  # Max total pages to crawl
    REQUEST_TIMEOUT = 10  # Seconds
    REQUEST_DELAY = 1  # Seconds between requests

    # Content extraction settings
    MIN_CONTENT_LENGTH = 100  # Minimum characters for content to be considered valid

    # RAG settings
    CHUNK_SIZE = 1000
    OVERLAP = 200
    CONFIDENCE_THRESHOLD = 0.6

    # Embedding and LLM settings
    EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
    LLM_MODEL = 'gemma3'
    CHROMA_PERSIST_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'chroma_db')
    SEMANTIC_MEMORY_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'semantic_memory')

    # Confidence thresholds
    THRESHOLDS = {
        'direct_knowledge': 0.6,
        'rag': 0.6,
        'web_search': 0.5,
        'memory_match': 0.15
    }

    # Output settings
    OUTPUT_LANGUAGE = "fa"  # Options: "fa" (Farsi), "en" (English)

    # Organic chemistry specific sites to prioritize
    PRIORITY_DOMAINS = [
        "pubchem.ncbi.nlm.nih.gov",
        "chemistrysteps.com",
        "masterorganicchemistry.com",
        "chemguide.co.uk",
        "organic-chemistry.org",
        "chemistryworld.com",
        "chemspider.com",
        "organicchemistrytutor.com",
        "chem.libretexts.org",
        "chemhelper.com",
        "arxiv.org",
        "jahaneshimi.com",
        "blog.faradars.org",
        "en.wikipedia.org",
        "fa.wikipedia.org"
    ]

class OrganicChemistryCrawler:
    """Crawler specialized for organic chemistry information with enhanced RAG capabilities"""

    def __init__(self, config=None):
        """Initialize the crawler with configuration"""
        self.config = config or Config()
        self.visited_urls = set()
        self.crawled_content = {}  # url -> content
        self.url_queue = []

        # Initialize semantic memory
        os.makedirs(self.config.SEMANTIC_MEMORY_DIR, exist_ok=True)
        self.semantic_memory = SemanticMemory(self.config.SEMANTIC_MEMORY_DIR)

        # Initialize embeddings and vector store
        try:
            self.embeddings = HuggingFaceEmbeddings(model_name=self.config.EMBEDDING_MODEL)
            logging.info(f"Initialized embeddings with model: {self.config.EMBEDDING_MODEL}")
        except Exception as e:
            logging.error(f"Error initializing embeddings: {e}")
            self.embeddings = None

        # RAG components will be initialized after crawling
        self.retriever = None
        self.vector_store = None

    def search_duckduckgo(self, query):
        """Search DuckDuckGo for organic chemistry information"""
        # Add organic chemistry context to the query
        if "organic chemistry" not in query.lower():
            search_query = f"{query} organic chemistry"
        else:
            search_query = query

        logging.info(f"Searching DuckDuckGo for: {search_query}")

        # DuckDuckGo doesn't have an official API, so we'll use their HTML search
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }

        # URL-encode the query (quote handles spaces and other special characters)
        encoded_query = quote(search_query)
        search_url = f"https://duckduckgo.com/html/?q={encoded_query}"

        try:
            response = requests.get(search_url, headers=headers, timeout=self.config.REQUEST_TIMEOUT)
            response.raise_for_status()

            # Use BeautifulSoup for more reliable parsing
            soup = BeautifulSoup(response.text, 'html.parser')
            result_urls = []

            # Get results from the result items
            for result in soup.select('.result__a'):
                href = result.get('href')
                if href and href.startswith('http'):
                    result_urls.append(href)

            if not result_urls:
                # Fallback to regex pattern
                url_pattern = r'<a[^>]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"'
                result_urls = re.findall(url_pattern, response.text)

            results = []
            for href in result_urls:
                if href and href.startswith('http'):
                    results.append(href)

            # Prioritize results from known chemistry domains
            prioritized_results = []
            other_results = []

            for url in results:
                domain = urlparse(url).netloc
                if any(priority_domain in domain for priority_domain in self.config.PRIORITY_DOMAINS):
                    prioritized_results.append(url)
                else:
                    other_results.append(url)

            # Combine prioritized and other results, deduplicating as we go
            combined_results = []
            seen_urls = set()

            # First add results from priority domains
            for url in prioritized_results:
                if url not in seen_urls:
                    combined_results.append(url)
                    seen_urls.add(url)

            # Then add the remaining results
            for url in other_results:
                if url not in seen_urls:
                    combined_results.append(url)
                    seen_urls.add(url)

            return combined_results[:self.config.NUM_SEARCH_RESULTS]

        except Exception as e:
            logging.error(f"Error searching DuckDuckGo: {e}")
            return []

    def search_arxiv(self, query):
        """Search arXiv for organic chemistry papers"""
        logging.info(f"Searching arXiv for: {query}")

        # Add organic chemistry context to the query
        if "organic chemistry" not in query.lower():
            search_query = f"{query} organic chemistry"
        else:
            search_query = query

        # URL-encode the query
        encoded_query = quote(search_query)

        # arXiv API endpoint
        search_url = f"http://export.arxiv.org/api/query?search_query=all:{encoded_query}&start=0&max_results={self.config.NUM_SEARCH_RESULTS}"

        try:
            response = requests.get(search_url, timeout=self.config.REQUEST_TIMEOUT)
            response.raise_for_status()

            # Parse the XML response using regex
            xml = response.text

            # Extract entry links using regex
            entry_pattern = r'<entry>.*?<id>(.*?)</id>.*?</entry>'
            entries = re.findall(entry_pattern, xml, re.DOTALL)

            results = []
            for entry_id in entries:
                if entry_id:
                    results.append(entry_id)

            return results

        except Exception as e:
            logging.error(f"Error searching arXiv: {e}")
            return []

    def search(self, query):
        """Search for organic chemistry information using configured search engine"""
        if self.config.SEARCH_ENGINE == "duckduckgo":
            return self.search_duckduckgo(query)
        elif self.config.SEARCH_ENGINE == "arxiv":
            return self.search_arxiv(query)
        elif self.config.SEARCH_ENGINE == "combined":
            # Use both search engines and combine results
            duckduckgo_results = self.search_duckduckgo(query)
            arxiv_results = self.search_arxiv(query)

            # Combine and deduplicate results
            combined_results = []
            seen_urls = set()

            # First add DuckDuckGo results
            for url in duckduckgo_results:
                if url not in seen_urls:
                    combined_results.append(url)
                    seen_urls.add(url)

            # Then add arXiv results
            for url in arxiv_results:
                if url not in seen_urls:
                    combined_results.append(url)
                    seen_urls.add(url)

            return combined_results[:self.config.NUM_SEARCH_RESULTS]
        else:
            logging.error(f"Unknown search engine: {self.config.SEARCH_ENGINE}")
            return []

    def extract_content(self, html, url):
        """Extract relevant content from HTML using BeautifulSoup"""
        soup = BeautifulSoup(html, 'html.parser')

        # Remove script, style, and nav elements
        for tag in ['script', 'style', 'nav', 'header', 'footer']:
            for element in soup.find_all(tag):
                element.decompose()

        # Extract title
        title = soup.title.text.strip() if soup.title else urlparse(url).path

        # Try to find main content
        content = ""

        # Try article tags first
        article_content = []
        for article in soup.find_all('article'):
            text = article.get_text(separator=' ', strip=True)
            if len(text) > self.config.MIN_CONTENT_LENGTH:
                article_content.append(text)

        if article_content:
            content = "\n\n".join(article_content)
        else:
            # Try content divs
            for div in soup.find_all('div', class_=lambda c: c and any(term in str(c).lower() for term in ['content', 'main', 'article', 'body'])):
                text = div.get_text(separator=' ', strip=True)
                if len(text) > self.config.MIN_CONTENT_LENGTH:
                    content += text + "\n\n"

        # If still no content, extract all paragraphs
        if not content or len(content) < self.config.MIN_CONTENT_LENGTH:
            paragraphs = [p.get_text(separator=' ', strip=True) for p in soup.find_all('p') if len(p.get_text(strip=True)) > 20]
            if paragraphs:
                content = "\n\n".join(paragraphs)

        # Clean up content (a separator is passed to get_text above so adjacent
        # elements don't run together; here we just collapse whitespace)
        content = re.sub(r'\s+', ' ', content).strip()

        return {
            "title": title,
            "content": content,
            "url": url
        }

    def extract_links(self, html, base_url):
        """Extract links from HTML to follow"""
        soup = BeautifulSoup(html, 'html.parser')
        links = []

        for a_tag in soup.find_all('a', href=True):
            href = a_tag['href']

            # Skip empty links, anchors, or javascript
            if not href or href.startswith('#') or href.startswith('javascript:'):
                continue

            # Convert relative URLs to absolute
            absolute_url = urljoin(base_url, href)

            # Skip non-HTTP links
            if not absolute_url.startswith(('http://', 'https://')):
                continue

            # Skip already visited URLs
            if absolute_url in self.visited_urls:
                continue

            # Prioritize chemistry domains
            domain = urlparse(absolute_url).netloc
            if any(priority_domain in domain for priority_domain in self.config.PRIORITY_DOMAINS):
                links.insert(0, absolute_url)  # Add to beginning of list
            else:
                links.append(absolute_url)

        # Return limited number of links
        return links[:self.config.MAX_LINKS_PER_PAGE]

    def crawl_url(self, url, depth=0):
        """Crawl a single URL and extract content"""
        if url in self.visited_urls or len(self.crawled_content) >= self.config.MAX_TOTAL_PAGES:
            return

        logging.info(f"Crawling: {url} (depth {depth})")
        self.visited_urls.add(url)

        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }
            response = requests.get(url, headers=headers, timeout=self.config.REQUEST_TIMEOUT)
            response.raise_for_status()

            # Skip non-HTML content
            content_type = response.headers.get('Content-Type', '')
            if 'text/html' not in content_type and 'application/xhtml+xml' not in content_type:
                logging.info(f"Skipping non-HTML content: {url} ({content_type})")
                return

            # Extract content
            content_data = self.extract_content(response.text, url)

            # Only save if we have meaningful content
            if len(content_data["content"]) > self.config.MIN_CONTENT_LENGTH:
                self.crawled_content[url] = content_data

            # Follow links if we haven't reached max depth
            if depth < self.config.MAX_DEPTH:
                links = self.extract_links(response.text, url)
                for link in links:
                    if link not in self.visited_urls:
                        self.url_queue.append((link, depth + 1))

            # Respect rate limits
            time.sleep(self.config.REQUEST_DELAY)

        except Exception as e:
            logging.error(f"Error crawling {url}: {e}")

    def process_queue(self):
        """Process the URL queue with multithreading"""
        with ThreadPoolExecutor(max_workers=5) as executor:
            while self.url_queue and len(self.crawled_content) < self.config.MAX_TOTAL_PAGES:
                # Get a batch of URLs to process
                batch = []
                while self.url_queue and len(batch) < 5:
                    batch.append(self.url_queue.pop(0))

                # Process the batch
                futures = [executor.submit(self.crawl_url, url, depth) for url, depth in batch]
                for future in futures:
                    future.result()  # Wait for completion

    def crawl(self, query):
        """Search and crawl for information about the query"""
        # Step 1: Search for initial URLs
        initial_urls = self.search(query)
        if not initial_urls:
            logging.warning(f"No search results found for query: {query}")
            return {}

        # Step 2: Initialize crawling queue
        self.url_queue = [(url, 0) for url in initial_urls]
        self.visited_urls = set()
        self.crawled_content = {}

        # Step 3: Process the queue
        self.process_queue()

        # Step 4: Return the crawled content
        logging.info(f"Crawling complete. Found {len(self.crawled_content)} pages with content.")
        return self.crawled_content

    def chunk_text(self, text, chunk_size=None, overlap=None):
        """Split text into chunks with overlap"""
        if chunk_size is None:
            chunk_size = self.config.CHUNK_SIZE
        if overlap is None:
            overlap = self.config.OVERLAP

        # If text is shorter than chunk size, return as is
        if len(text) <= chunk_size:
            return [text]

        chunks = []
        start = 0

        while start < len(text):
            # Get chunk of specified size
            end = start + chunk_size

            # Adjust end to avoid cutting words
            if end < len(text):
                # Try to find a space to break at
                while end > start and text[end] != ' ':
                    end -= 1
                if end == start:  # If no space found, use the original end
                    end = start + chunk_size

            # Add chunk to list
            chunks.append(text[start:end])

            # Move start position for next chunk, considering overlap.
            # The max() guard keeps the loop advancing even when the chunk was
            # cut back close to its start (e.g. a long run without spaces).
            start = max(end - overlap, start + 1)

        return chunks

    def prepare_documents(self):
        """Prepare crawled content as documents for RAG using LangChain Document format"""
        documents = []

        for url, data in self.crawled_content.items():
            content = data["content"]
            title = data["title"]

            # Chunk the content
            chunks = self.chunk_text(content)

            # Create documents from chunks
            for i, chunk in enumerate(chunks):
                doc = Document(
                    page_content=chunk,
                    metadata={
                        "source": url,
                        "title": title,
                        "chunk": i + 1,
                        "total_chunks": len(chunks)
                    }
                )
                documents.append(doc)

        return documents

    def initialize_retriever(self, documents):
        """Initialize the hybrid retriever with vector search and BM25"""
        if not documents or not self.embeddings:
            logging.error("No documents or embeddings available for retriever initialization")
            return None

        try:
            # Create BM25 retriever
            bm25_retriever = BM25Retriever.from_documents(documents)
            bm25_retriever.k = 5  # Top k results to retrieve

            # Initialize or recreate vector store
            if os.path.exists(self.config.CHROMA_PERSIST_DIR):
                import shutil
                logging.info("Removing existing Chroma DB to prevent dimension mismatch")
                shutil.rmtree(self.config.CHROMA_PERSIST_DIR)

            # Create vector store
            os.makedirs(self.config.CHROMA_PERSIST_DIR, exist_ok=True)
            vector_store = Chroma.from_documents(
                documents=documents,
                embedding=self.embeddings,
                persist_directory=self.config.CHROMA_PERSIST_DIR
            )
            self.vector_store = vector_store

            # Create vector retriever
            vector_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 5})

            # Create hybrid retriever (BM25 70%, Vector 30%)
            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever, vector_weight=0.3)
            logging.info("Hybrid retriever initialized successfully")

            return hybrid_retriever
        except Exception as e:
            logging.error(f"Error initializing retriever: {e}")
            traceback.print_exc()
            return None

    def check_corrections(self, query):
        """Check if a correction exists for this query using semantic memory"""
        logging.info("Checking semantic memory for corrections...")

        # Use semantic memory to find similar queries
        stored_query, answer, similarity = self.semantic_memory.retrieve_memory(query)

        if stored_query and answer:
            logging.info(f"Found semantic match in memory with similarity: {similarity:.2f}")
            logging.info(f"Original query: '{stored_query}'")
            logging.info(f"Current query: '{query}'")
            return answer, f"Semantic Memory (similarity: {similarity:.2f})"

        return None, None

    def estimate_confidence(self, text, query, context=None):
        """Estimate confidence of response using more sophisticated analysis"""
        # Start with baseline confidence
        confidence = 0.5

        # Check for uncertainty markers (English phrases are lowercase because
        # they are matched against text.lower())
        uncertainty_phrases = [
            "نمیدانم", "مطمئن نیستم", "ممکن است", "شاید", "احتمالاً",
            "فکر میکنم", "به نظر میرسد", "i don't know", "not sure",
            "might be", "perhaps", "possibly", "it seems"
        ]

        if any(phrase in text.lower() for phrase in uncertainty_phrases):
            confidence -= 0.2

        # Check for question relevance
        query_words = set(re.findall(r'\b\w+\b', query.lower()))
        text_words = set(re.findall(r'\b\w+\b', text.lower()))

        # Calculate overlap between query and response
        if query_words:
            overlap_ratio = len(query_words.intersection(text_words)) / len(query_words)
            if overlap_ratio > 0.5:
                confidence += 0.2
            elif overlap_ratio < 0.2:
                confidence -= 0.2

        # Check for chemistry-specific terms
        chemistry_terms = [
            "molecule", "compound", "reaction", "bond", "carbon", "hydrogen", "oxygen",
            "nitrogen", "synthesis", "organic", "chemical", "structure", "formula",
            "مولکول", "ترکیب", "واکنش", "پیوند", "کربن", "هیدروژن", "اکسیژن",
            "نیتروژن", "سنتز", "آلی", "شیمیایی", "ساختار", "فرمول"
        ]

        chem_term_count = sum(1 for term in chemistry_terms if term.lower() in text.lower())
        term_factor = min(chem_term_count / 5, 1.0) * 0.2
        confidence += term_factor

        # If context provided, check context relevance
        if context:
            context_words = set(re.findall(r'\b\w+\b', context.lower()))
            if context_words:
                context_overlap = len(context_words.intersection(text_words)) / len(context_words)
                if context_overlap > 0.3:
                    confidence += 0.2
                else:
                    confidence -= 0.1

        # Higher confidence for longer, more detailed responses
        if len(text) > 500:
            confidence += 0.1
        elif len(text) < 100:
            confidence -= 0.1

        # Ensure confidence is within [0, 1]
        return max(0.0, min(1.0, confidence))
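
    # Worked example of the heuristic above: a 600-character answer that shares
    # more than half of the query's words (+0.2), mentions five or more chemistry
    # terms (+0.2), overlaps the retrieved context by more than 30% (+0.2), and
    # exceeds 500 characters (+0.1) scores 0.5 + 0.2 + 0.2 + 0.2 + 0.1 = 1.2,
    # clamped to 1.0. An answer containing an uncertainty phrase loses 0.2.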

    def check_direct_knowledge(self, query):
        """Check if the LLM can answer directly from its knowledge"""
        logging.info("Checking LLM's direct knowledge...")
        try:
            output_language = "فارسی" if self.config.OUTPUT_LANGUAGE == "fa" else "English"

            # Prompt (in Farsi): "Answer this question about organic chemistry using
            # your own knowledge. Answer in {output_language}. Question: ... Answer:"
            prompt = f"""به این سوال در مورد شیمی آلی با استفاده از دانش خود پاسخ دهید. به زبان {output_language} پاسخ دهید.

سوال: {query}

پاسخ:"""

            response = query_llm(prompt, model=self.config.LLM_MODEL)
            confidence = self.estimate_confidence(response, query)
            logging.info(f"LLM direct knowledge confidence: {confidence:.2f}")

            return response, confidence
        except Exception as e:
            logging.error(f"Error in direct knowledge check: {e}")
            return "Error processing response", 0.0

    def rag_query(self, query):
        """Use RAG to retrieve and generate an answer based on crawled content"""
        # Prepare documents from crawled content
        documents = self.prepare_documents()

        if not documents:
            logging.warning("No documents available for RAG")
            if self.config.OUTPUT_LANGUAGE == "fa":
                return "متاسفانه اطلاعاتی در مورد این موضوع پیدا نکردم.", 0.0, []
            else:
                return "I couldn't find any information about that topic.", 0.0, []

        # Initialize retriever if not already done
        if not self.retriever:
            self.retriever = self.initialize_retriever(documents)

        if not self.retriever:
            logging.error("Failed to initialize retriever")
            if self.config.OUTPUT_LANGUAGE == "fa":
                return "خطا در پردازش اطلاعات رخ داده است.", 0.0, []
            else:
                return "An error occurred while processing information.", 0.0, []

        try:
            # Retrieve relevant documents
            relevant_docs = self.retriever.get_relevant_documents(query)

            if not relevant_docs:
                logging.warning("No relevant documents found")
                if self.config.OUTPUT_LANGUAGE == "fa":
                    return "متاسفانه اطلاعات مرتبطی پیدا نکردم.", 0.0, []
                else:
                    return "I couldn't find any relevant information.", 0.0, []

            # Prepare context from retrieved documents
            context = "\n\n".join([
                f"Source: {doc.metadata.get('title')} ({doc.metadata.get('source')})\n{doc.page_content}"
                for doc in relevant_docs[:5]
            ])

            # Extract unique sources
            sources = list(set(doc.metadata.get('source') for doc in relevant_docs[:5]))

            # Prepare prompt for LLM
            output_language = "فارسی" if self.config.OUTPUT_LANGUAGE == "fa" else "English"

            # Prompt (in Farsi): "Given the information below, answer the question
            # about organic chemistry. Answer in {output_language}.
            # Information: ... Question: ... Answer:"
            prompt = f"""با توجه به اطلاعات زیر، به سوال در مورد شیمی آلی پاسخ دهید. به زبان {output_language} پاسخ دهید.

اطلاعات:
{context}

سوال: {query}

پاسخ:"""

            # Query LLM
            response = query_llm(prompt, model=self.config.LLM_MODEL)

            # Estimate confidence
            confidence = self.estimate_confidence(response, query, context)
            logging.info(f"RAG confidence: {confidence:.2f}")

            # Add source attribution ("This information was compiled from N sources.")
            if self.config.OUTPUT_LANGUAGE == "fa":
                response += f"\n\nاین اطلاعات از {len(sources)} منبع گردآوری شده است."
            else:
                response += f"\n\nThis information was compiled from {len(sources)} sources."

            return response, confidence, sources
        except Exception as e:
            logging.error(f"Error in RAG query: {e}")
            traceback.print_exc()
            if self.config.OUTPUT_LANGUAGE == "fa":
                return "خطا در پردازش اطلاعات رخ داده است.", 0.0, []
            else:
                return "An error occurred while processing information.", 0.0, []

    def get_answer(self, query):
        """Main method to get an answer following the agent-based architecture"""
        logging.info(f"Processing query: {query}")

        # STEP 1: Check corrections memory
        correction, source = self.check_corrections(query)
        if correction:
            return f"{correction}\n\n[Source: {source}]"

        # STEP 2: Try direct LLM knowledge
        direct_response, direct_confidence = self.check_direct_knowledge(query)

        if direct_confidence >= self.config.THRESHOLDS['direct_knowledge']:
            logging.info("Using direct LLM knowledge (high confidence)")
            return f"{direct_response}\n\n[Source: LLM Knowledge, Confidence: {direct_confidence:.2f}]"

        # STEP 3: Crawl and index content if not already done
        if not self.crawled_content:
            self.crawl(query)

        # STEP 4: Try RAG with crawled documents
        rag_response, rag_confidence, sources = self.rag_query(query)

        if rag_confidence >= self.config.THRESHOLDS['rag']:
            logging.info("Using RAG response (sufficient confidence)")
            sources_text = ", ".join(sources[:3])
            return f"{rag_response}\n\n[Source: Web Content, Confidence: {rag_confidence:.2f}, Sources: {sources_text}]"

        # STEP 5: Fall back to direct response with warning
        logging.info("No high-confidence source found, using direct response with warning")
        return f"{direct_response}\n\n[Warning: Low confidence ({direct_confidence:.2f}). Please verify this information.]"

    def add_correction(self, incorrect_query, correct_answer):
        """Add a correction to semantic memory"""
        try:
            # Add the correction to semantic memory
            success = self.semantic_memory.add_memory(
                incorrect_query,
                correct_answer,
                {"type": "correction", "timestamp": str(datetime.datetime.now())}
            )

            if success:
                logging.info(f"Added correction for: '{incorrect_query}'")

            return success
        except Exception as e:
            logging.error(f"Error adding correction: {e}")
            return False

    def save_results(self, query, output_file=None, answer=None, confidence=None):
        """Save crawled results to a JSON file and answer to a text file"""
        # Create a safe filename based on the query
        safe_query = re.sub(r'[^\w\s-]', '', query).strip().lower()
        safe_query = re.sub(r'[-\s]+', '_', safe_query)
        timestamp = int(time.time())

        # Save crawled content to JSON
        if not output_file:
            output_file = f"organic_chemistry_{safe_query}_{timestamp}.json"

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump({
                "query": query,
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "results": self.crawled_content
            }, f, ensure_ascii=False, indent=2)

        logging.info(f"Results saved to {output_file}")

        # Save answer to text file if provided
        if answer:
            answer_file = "organic_chemistry_results.txt"

            # Append to existing results file
            with open(answer_file, 'a', encoding='utf-8') as f:
                f.write(f"\n{'='*80}\n")
                if self.config.OUTPUT_LANGUAGE == "fa":
                    f.write(f"سوال: {query}\n")
                    if confidence is not None:
                        f.write(f"اطمینان: {confidence:.2f}\n")
                else:
                    f.write(f"Query: {query}\n")
                    if confidence is not None:
                        f.write(f"Confidence: {confidence:.2f}\n")
                f.write(f"\n{answer}\n")

            logging.info(f"Answer saved to {answer_file}")
            return output_file, answer_file

        return output_file

# Simple direct LLM query function
def query_llm(prompt, model='gemma3'):
    """Query the LLM model directly."""
    try:
        # In a real implementation, this would use the LLM's native API
        from transformers import pipeline
        pipe = pipeline("text-generation", model=model)
        response = pipe(prompt, max_length=1024, temperature=0.7)
        return response[0]["generated_text"].strip()
    except Exception as e:
        logging.error(f"Error querying LLM: {e}")
        # Return error message without hardcoded answers
        return f"Error: {str(e)}"
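
# The model name 'gemma3' in Config.LLM_MODEL looks like an Ollama model tag
# rather than a Hugging Face model id, so the transformers pipeline above may
# fail to load it. The sketch below is an alternative backend, assuming a local
# Ollama server on its default port (http://localhost:11434); it is not part of
# the original pipeline and is offered only as a drop-in replacement to try.
def query_llm_ollama(prompt, model='gemma3', base_url='http://localhost:11434'):
    """Query a local Ollama server (assumed backend) and return the response text."""
    try:
        # /api/generate returns a single JSON object when stream is disabled
        resp = requests.post(
            f"{base_url}/api/generate",
            json={"model": model, "prompt": prompt, "stream": False},
            timeout=120,
        )
        resp.raise_for_status()
        return resp.json().get("response", "").strip()
    except Exception as e:
        logging.error(f"Error querying Ollama: {e}")
        return f"Error: {str(e)}"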

# Hybrid retriever that combines BM25 and vector search
class HybridRetriever(BaseRetriever):
    """Hybrid retriever combining BM25 and vector search with configurable weights"""

    def __init__(self, vector_retriever, bm25_retriever, vector_weight=0.3):
        """Initialize with separate retrievers and weights"""
        super().__init__()
        # Store retrievers and weights
        self._vector_retriever = vector_retriever
        self._bm25_retriever = bm25_retriever
        self._vector_weight = vector_weight
        self._bm25_weight = 1.0 - vector_weight

    def get_relevant_documents(self, query):
        """Get relevant documents using weighted combination of retrievers"""
        try:
            # Get results from both retrievers
            vector_docs = self._vector_retriever.get_relevant_documents(query)
            bm25_docs = self._bm25_retriever.get_relevant_documents(query)

            # Create dictionary to track unique documents and their scores
            doc_dict = {}

            # Add vector docs with their weights
            for i, doc in enumerate(vector_docs):
                # Score based on position (inverse rank)
                score = (len(vector_docs) - i) * self._vector_weight
                doc_id = doc.page_content[:50]  # Use first 50 chars as a simple ID
                if doc_id in doc_dict:
                    doc_dict[doc_id]["score"] += score
                else:
                    doc_dict[doc_id] = {"doc": doc, "score": score}

            # Add BM25 docs with their weights
            for i, doc in enumerate(bm25_docs):
                # Score based on position (inverse rank)
                score = (len(bm25_docs) - i) * self._bm25_weight
                doc_id = doc.page_content[:50]  # Use first 50 chars as a simple ID
                if doc_id in doc_dict:
                    doc_dict[doc_id]["score"] += score
                else:
                    doc_dict[doc_id] = {"doc": doc, "score": score}

            # Sort by combined score (highest first)
            sorted_docs = sorted(doc_dict.values(), key=lambda x: x["score"], reverse=True)

            # Return just the document objects
            return [item["doc"] for item in sorted_docs]
        except Exception as e:
            logging.error(f"Error in hybrid retrieval: {e}")
            return []

    def _get_relevant_documents(self, query):
        """Required method to satisfy the abstract base class"""
        return self.get_relevant_documents(query)
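
# Worked example of the rank fusion above, with the default vector_weight=0.3
# (so bm25_weight=0.7) and five results from each retriever: a document ranked
# 1st by the vector retriever scores (5 - 0) * 0.3 = 1.5; if the same document
# is ranked 2nd by BM25 it also gets (5 - 1) * 0.7 = 2.8, for a combined 4.3.
# A document returned by only one retriever keeps its single score, so pages
# found by both searches tend to rise to the top of the merged list.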

# Semantic Memory class for storing and retrieving memories using embeddings
class SemanticMemory:
    """Semantic memory system using embeddings and vector database"""

    def __init__(self, persist_directory):
        """Initialize the semantic memory with embeddings"""
        self.embeddings = HuggingFaceEmbeddings(model_name=Config.EMBEDDING_MODEL)
        self.persist_directory = persist_directory

        # Create directory if it doesn't exist
        os.makedirs(persist_directory, exist_ok=True)

        # Initialize or load the vector store
        try:
            self.memory_store = Chroma(
                persist_directory=persist_directory,
                embedding_function=self.embeddings
            )
            logging.info(f"Loaded semantic memory from {persist_directory}")
        except Exception as e:
            logging.info(f"Creating new semantic memory: {e}")
            self.memory_store = Chroma(
                persist_directory=persist_directory,
                embedding_function=self.embeddings
            )

    def add_memory(self, query, answer, metadata=None):
        """Add a memory (query-answer pair) to the semantic memory"""
        if metadata is None:
            metadata = {"type": "correction", "timestamp": str(datetime.datetime.now())}

        # Create a document with the query as content and answer in metadata
        document = Document(
            page_content=query,
            metadata={"answer": answer, **metadata}
        )

        # Add to vector store
        self.memory_store.add_documents([document])
        self.memory_store.persist()
        logging.info(f"Added memory for query: '{query}'")
        return True

    def retrieve_memory(self, query, similarity_threshold=None):
        """Retrieve most similar memory to the query"""
        if similarity_threshold is None:
            similarity_threshold = Config.THRESHOLDS['memory_match']

        try:
            # Search for similar queries
            results = self.memory_store.similarity_search_with_score(query, k=5)

            if not results:
                return None, None, 0.0

            # Process all results to find the best match
            best_doc = None
            best_similarity = 0.0

            for doc, score in results:
                # Convert distance to similarity (Chroma returns distance, not similarity)
                # Using a simple inverse relationship for better cross-language matching
                similarity = 1.0 / (1.0 + score * 2)

                logging.info(f"Memory candidate: '{doc.page_content}' with similarity: {similarity:.4f}")

                if similarity > best_similarity:
                    best_similarity = similarity
                    best_doc = doc

            if best_similarity >= similarity_threshold:
                logging.info(f"Best memory match: '{best_doc.page_content}' with similarity: {best_similarity:.4f}")
                return best_doc.page_content, best_doc.metadata.get("answer"), best_similarity
            else:
                logging.info(f"Best memory match below threshold ({best_similarity:.4f} < {similarity_threshold})")

            return None, None, 0.0

        except Exception as e:
            logging.error(f"Error retrieving memory: {e}")
            return None, None, 0.0

    def get_all_memories(self):
        """Get all memories in the system"""
        try:
            return self.memory_store.get()
        except Exception as e:
            logging.error(f"Error getting all memories: {e}")
            return {"ids": [], "documents": [], "metadatas": []}
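
# Minimal usage sketch for SemanticMemory, assuming the embedding model in
# Config is available locally; the similarity value is the inverse-distance
# score computed in retrieve_memory, not a cosine similarity.
#
#     memory = SemanticMemory(Config.SEMANTIC_MEMORY_DIR)
#     memory.add_memory("What is an SN2 reaction?",
#                       "A one-step substitution in which the nucleophile attacks "
#                       "the carbon from the side opposite the leaving group.")
#     stored_query, answer, similarity = memory.retrieve_memory("SN2 mechanism")
#     print(answer, similarity)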

def main():
    parser = argparse.ArgumentParser(description="Organic Chemistry Web Crawler with Advanced RAG")

    # Define command modes
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument("--query", "-q", help="The chemistry query to search for")
    mode_group.add_argument("--add-correction", action="store_true", help="Add a correction to memory")

    # Query mode parameters
    parser.add_argument("--engine", choices=["duckduckgo", "arxiv", "combined"], default="combined",
                        help="Search engine to use (default: combined)")
    parser.add_argument("--depth", type=int, default=1,
                        help="Crawling depth (default: 1)")
    parser.add_argument("--max-pages", type=int, default=20,
                        help="Maximum pages to crawl (default: 20)")
    parser.add_argument("--output", help="Output JSON file (default: auto-generated)")
    parser.add_argument("--language", choices=["fa", "en"], default="fa",
                        help="Output language (default: fa for Farsi)")

    # Correction mode parameters
    parser.add_argument("--incorrect", help="The incorrect query to add a correction for")
    parser.add_argument("--correct", help="The correct answer for the query")

    args = parser.parse_args()

    # Configure crawler
    config = Config()
    config.SEARCH_ENGINE = args.engine
    config.MAX_DEPTH = args.depth
    config.MAX_TOTAL_PAGES = args.max_pages
    config.OUTPUT_LANGUAGE = args.language

    # Create crawler
    crawler = OrganicChemistryCrawler(config)

    if args.add_correction:
        # Add a correction to semantic memory
        if not args.incorrect or not args.correct:
            parser.error("Both --incorrect and --correct are required for adding a correction")

        success = crawler.add_correction(args.incorrect, args.correct)
        if success:
            print(f"Correction added successfully for query: '{args.incorrect}'")
        else:
            print("Failed to add correction")
    else:
        # Process a query
        query = args.query
        print(f"\nProcessing query: {query}")

        # Get answer using the agent-based approach
        answer = crawler.get_answer(query)

        # Extract the reported confidence from the answer tag, e.g.
        # "[Source: ..., Confidence: 0.85]" or "[Warning: Low confidence (0.42)..."
        confidence = 0.0
        match = re.search(r"Confidence: ([\d.]+)", answer) or re.search(r"confidence \(([\d.]+)\)", answer)
        if match:
            confidence = float(match.group(1))

        output_file, answer_file = crawler.save_results(query, args.output, answer, confidence)

        print(f"\nProcessing complete! Results saved to: {output_file}")
        print(f"Found information from {len(crawler.crawled_content)} web pages.")
        print("\nAnswer:")
        print("=" * 80)
        print(answer)
        print("=" * 80)
        print(f"\nFull answer saved to: {answer_file}")

if __name__ == "__main__":
    main()
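
# Example invocations (the script filename is assumed here):
#
#     python organic_chemistry_crawler.py --query "SN1 vs SN2 reactions" --engine combined --language en
#     python organic_chemistry_crawler.py -q "خواص شیمیایی بنزن" --depth 1 --max-pages 10
#     python organic_chemistry_crawler.py --add-correction --incorrect "boiling point of benzene" \
#         --correct "Benzene boils at about 80 °C (176 °F) at atmospheric pressure."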