Upload files to "/"
This commit is contained in:
parent dfdeab2855
commit 4c9da22a39

Multimodal.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import os
import subprocess

# Clear proxy settings
def clear_proxy_settings():
    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
        if var in os.environ:
            del os.environ[var]

clear_proxy_settings()

import os
import tempfile
import subprocess
from datetime import datetime

import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
from langchain_text_splitters import RecursiveCharacterTextSplitter
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
from search_utils import duckduckgo_search, rank_results

template = """
تو یک دستیار هستی که از یک داده های متنی و تصویری استفاده میکنی تا به سوالات کاربر به زبان فارسی سلیس پاسخ بدی.
Question: {question}
Context: {context}
Answer:
"""

pdfs_directory = 'multi-modal-rag/pdfs/'
figures_directory = 'multi-modal-rag/figures/'
images_directory = 'multi-modal-rag/images/'
videos_directory = 'multi-modal-rag/videos/'
audio_directory = 'multi-modal-rag/audio/'
frames_directory = 'multi-modal-rag/frames/'

# Create directories if they don't exist
os.makedirs(pdfs_directory, exist_ok=True)
os.makedirs(figures_directory, exist_ok=True)
os.makedirs(images_directory, exist_ok=True)
os.makedirs(videos_directory, exist_ok=True)
os.makedirs(audio_directory, exist_ok=True)
os.makedirs(frames_directory, exist_ok=True)

embeddings = OllamaEmbeddings(model="llama3.2")
vector_store = InMemoryVectorStore(embeddings)

model = OllamaLLM(model="gemma3")

def upload_pdf(file):
    with open(pdfs_directory + file.name, "wb") as f:
        f.write(file.getbuffer())

def upload_image(file):
    with open(images_directory + file.name, "wb") as f:
        f.write(file.getbuffer())
    return images_directory + file.name

def upload_video(file):
    file_path = videos_directory + file.name
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())
    return file_path

def upload_audio(file):
    file_path = audio_directory + file.name
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())
    return file_path

def load_pdf(file_path):
    elements = partition_pdf(
        file_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=figures_directory
    )

    text_elements = [element.text for element in elements if element.category not in ["Image", "Table"]]

    for file in os.listdir(figures_directory):
        extracted_text = extract_text(figures_directory + file)
        text_elements.append(extracted_text)

    return "\n\n".join(text_elements)

def extract_frames(video_path, num_frames=5):
    """Extract frames from video file and save them to frames directory"""
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    base_name = os.path.basename(video_path).split('.')[0]
    frame_paths = []

    # Extract frames using ffmpeg
    for i in range(num_frames):
        frame_path = f"{frames_directory}{base_name}_{timestamp}_{i}.jpg"
        cmd = [
            'ffmpeg', '-i', video_path,
            '-ss', str(i * (1/num_frames)), '-vframes', '1',
            '-q:v', '2', frame_path, '-y'
        ]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            frame_paths.append(frame_path)
        except subprocess.CalledProcessError:
            st.warning(f"Failed to extract frame {i} from video")

    return frame_paths

def process_audio(audio_path):
    """Process audio file using the model"""
    audio_description = model.invoke(
        f"Describe what you hear in this audio file: {os.path.basename(audio_path)}"
    )
    return f"Audio file: {os.path.basename(audio_path)}. Description: {audio_description}"

def extract_text(file_path):
    model_with_image_context = model.bind(images=[file_path])
    return model_with_image_context.invoke("Tell me what do you see in this picture.")

def split_text(text):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True
    )

    return text_splitter.split_text(text)

def index_docs(texts):
    vector_store.add_texts(texts)

def retrieve_docs(query):
    return vector_store.similarity_search(query)

def answer_question(question, documents):
    local_context = "\n\n".join([doc.page_content for doc in documents])
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | model
    return chain.invoke({"question": question, "context": local_context})

# Sidebar for upload options
st.sidebar.title("Upload Documents")
upload_option = st.sidebar.radio("Choose upload type:", ["PDF", "Image", "Video", "Audio", "Search"])

if upload_option == "Search":
    st.title("Web Search with BM25 Ranking")
    search_query = st.text_input("Enter your search query:")

    if search_query:
        with st.spinner("Searching and ranking results..."):
            # Get search results
            search_results = duckduckgo_search(search_query, max_results=10)

            if search_results:
                # Rank results using BM25
                ranked_results = rank_results(search_query, search_results)

                # Display results
                st.subheader("Ranked Search Results")
                for i, result in enumerate(ranked_results):
                    with st.expander(f"{i+1}. {result.title}"):
                        st.write(f"**Snippet:** {result.snippet}")
                        st.write(f"**URL:** {result.url}")

                # Option to ask about search results
                st.subheader("Ask about these results")
                question = st.text_input("Enter your question about the search results:")

                if question:
                    # Prepare context from top results
                    context = "\n\n".join([f"Title: {r.title}\nSnippet: {r.snippet}" for r in ranked_results[:3]])

                    # Use the model to answer
                    prompt = ChatPromptTemplate.from_template(template)
                    chain = prompt | model

                    with st.spinner("Generating answer..."):
                        response = chain.invoke({"question": question, "context": context})
                        st.markdown("### Answer")
                        st.write(response)
            else:
                st.warning("No search results found")

elif upload_option == "PDF":
    uploaded_file = st.file_uploader(
        "Upload PDF",
        type="pdf",
        accept_multiple_files=False
    )

    if uploaded_file:
        upload_pdf(uploaded_file)
        with st.spinner("Processing PDF..."):
            text = load_pdf(pdfs_directory + uploaded_file.name)
            chunked_texts = split_text(text)
            index_docs(chunked_texts)
        st.success("PDF processed successfully!")

elif upload_option == "Image":
    uploaded_image = st.file_uploader(
        "Upload Image",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=False
    )

    if uploaded_image:
        image_path = upload_image(uploaded_image)
        st.image(image_path, caption="Uploaded Image", use_column_width=True)
        with st.spinner("Processing image..."):
            image_description = extract_text(image_path)
            index_docs([image_description])
        st.success("Image processed and added to knowledge base")

elif upload_option == "Video":
    uploaded_video = st.file_uploader(
        "Upload Video",
        type=["mp4", "avi", "mov", "mkv"],
        accept_multiple_files=False
    )

    if uploaded_video:
        video_path = upload_video(uploaded_video)
        st.video(video_path)

        with st.spinner("Processing video frames..."):
            frame_paths = extract_frames(video_path)
            video_descriptions = []

            for frame_path in frame_paths:
                st.image(frame_path, caption="Frame from video", width=200)
                frame_description = extract_text(frame_path)
                video_descriptions.append(frame_description)

            # Add a combined description
            combined_description = f"Video file: {uploaded_video.name}. Content description: " + " ".join(video_descriptions)
            index_docs([combined_description])
        st.success("Video processed and added to knowledge base")

else:  # Audio option
    uploaded_audio = st.file_uploader(
        "Upload Audio",
        type=["mp3", "wav", "ogg"],
        accept_multiple_files=False
    )

    if uploaded_audio:
        audio_path = upload_audio(uploaded_audio)
        st.audio(audio_path)

        with st.spinner("Processing audio..."):
            # For audio, we'll use the model directly without visual context
            audio_description = process_audio(audio_path)
            index_docs([audio_description])
        st.success("Audio processed and added to knowledge base")

# Chat interface
question = st.chat_input()

if question:
    st.chat_message("user").write(question)
    related_documents = retrieve_docs(question)
    answer = answer_question(question, related_documents)
    st.chat_message("assistant").write(answer)
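Multimodal.py imports duckduckgo_search and rank_results from a search_utils module that is not part of this commit. The sketch below is a hypothetical stand-in showing one way that interface could be satisfied, assuming the DuckDuckGo HTML endpoint already used elsewhere in this repository and the rank_bm25 package for BM25 scoring; the SearchResult fields mirror how results are accessed above (result.title, result.snippet, result.url), and the CSS selectors may need adjusting.

# search_utils.py -- hypothetical sketch, not included in this commit
from dataclasses import dataclass

import requests
from bs4 import BeautifulSoup
from rank_bm25 import BM25Okapi  # assumed dependency for BM25 ranking


@dataclass
class SearchResult:
    title: str
    snippet: str
    url: str


def duckduckgo_search(query, max_results=10):
    # Scrape the DuckDuckGo HTML endpoint (the same endpoint the other scripts in this commit use).
    response = requests.get(
        "https://html.duckduckgo.com/html/",
        params={"q": query},
        headers={"User-Agent": "Mozilla/5.0"},
        timeout=10,
    )
    soup = BeautifulSoup(response.text, "html.parser")
    results = []
    for block in soup.select("div.result")[:max_results]:
        link = block.select_one("a.result__a")
        snippet = block.select_one("a.result__snippet")
        if link is not None:
            results.append(SearchResult(
                title=link.get_text(strip=True),
                snippet=snippet.get_text(strip=True) if snippet else "",
                url=link.get("href", ""),
            ))
    return results


def rank_results(query, results):
    # Re-rank results by BM25 score of the query against "title snippet".
    if not results:
        return []
    corpus = [f"{r.title} {r.snippet}".lower().split() for r in results]
    scores = BM25Okapi(corpus).get_scores(query.lower().split())
    ranked = sorted(zip(scores, results), key=lambda pair: pair[0], reverse=True)
    return [r for _, r in ranked]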
@@ -1,340 +1,460 @@
-import os
-import pickle
-import json
-import nltk
-import requests
-import time
-
-from bs4 import BeautifulSoup
-from urllib.parse import quote
-from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_community.retrievers import BM25Retriever
-
-
-try:
-    nltk.data.find('tokenizers/punkt')
-except LookupError:
-    nltk.download('punkt')
-
-
-class ModularRAG:
-    def __init__(self):
-        self.storage_path = "./rag_data"
-
-        if not os.path.exists(self.storage_path):
-            os.makedirs(self.storage_path)
-            os.makedirs(os.path.join(self.storage_path, "documents"))
-            os.makedirs(os.path.join(self.storage_path, "web_results"))
-
-        self.documents = []
-        self.web_results = []
-
-        # Web crawler settings
-        self.headers = {
-            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
-        }
-        self.num_search_results = 10
-        self.max_depth = 2
-        self.max_links_per_page = 5
-        self.max_paragraphs = 5
-
-        self._load_saved_data()
-
-    def _load_saved_data(self):
-        doc_path = os.path.join(self.storage_path, "documents", "docs.pkl")
-        web_path = os.path.join(self.storage_path, "web_results", "web.json")
-
-        if os.path.exists(doc_path):
-            try:
-                with open(doc_path, 'rb') as f:
-                    self.documents = pickle.load(f)
-            except Exception as e:
-                print(f"خطا در بارگیری اسناد: {e}")
-
-        if os.path.exists(web_path):
-            try:
-                with open(web_path, 'r', encoding='utf-8') as f:
-                    self.web_results = json.load(f)
-            except Exception as e:
-                print(f"خطا در بارگیری نتایج وب: {e}")
-
-    def _save_documents(self):
-        doc_path = os.path.join(self.storage_path, "documents", "docs.pkl")
-        try:
-            with open(doc_path, 'wb') as f:
-                pickle.dump(self.documents, f)
-        except Exception as e:
-            print(f"خطا در ذخیرهسازی اسناد: {e}")
-
-    def _save_web_results(self):
-        web_path = os.path.join(self.storage_path, "web_results", "web.json")
-        try:
-            with open(web_path, 'w', encoding='utf-8') as f:
-                json.dump(self.web_results, f, ensure_ascii=False, indent=2)
-        except Exception as e:
-            print(f"خطا در ذخیرهسازی نتایج وب: {e}")
-
-    def load_pdf(self, file_path):
-        if not os.path.exists(file_path):
-            raise FileNotFoundError(f"فایل یافت نشد: {file_path}")
-
-        try:
-            loader = PDFPlumberLoader(file_path)
-            documents = loader.load()
-
-            if documents:
-                text_splitter = RecursiveCharacterTextSplitter(
-                    chunk_size=1000,
-                    chunk_overlap=200,
-                    add_start_index=True
-                )
-                chunked_docs = text_splitter.split_documents(documents)
-
-                self.documents.extend(chunked_docs)
-                self._save_documents()
-                return len(chunked_docs)
-            return 0
-        except Exception as e:
-            raise Exception(f"خطا در بارگیری PDF: {e}")
-
-    def search_duckduckgo(self, query, num_results=None):
-        if num_results is None:
-            num_results = self.num_search_results
-
-        try:
-            search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
-            response = requests.get(search_url, headers=self.headers, timeout=10)
-
-            if response.status_code != 200:
-                print(f"خطا در جستجوی وب: HTTP {response.status_code}")
-                return []
-
-            soup = BeautifulSoup(response.text, 'html.parser')
-            results = []
-
-            for element in soup.select('.result__url, .result__a'):
-                href = element.get('href') if 'href' in element.attrs else None
-
-                if href and not href.startswith('/') and (href.startswith('http://') or href.startswith('https://')):
-                    results.append(href)
-                elif not href and element.find('a') and 'href' in element.find('a').attrs:
-                    href = element.find('a')['href']
-                    if href and not href.startswith('/'):
-                        results.append(href)
-
-            unique_results = list(set(results))
-            return unique_results[:num_results]
-
-        except Exception as e:
-            print(f"خطا در جستجوی DuckDuckGo: {e}")
-            return []
-
-    def crawl_page(self, url, depth=0):
-        if depth > self.max_depth:
-            return None, []
-
-        try:
-            response = requests.get(url, headers=self.headers, timeout=10)
-            response.raise_for_status()
-
-            soup = BeautifulSoup(response.text, 'html.parser')
-
-            title = soup.title.string if soup.title else "بدون عنوان"
-
-            paragraphs = []
-            for p in soup.find_all('p'):
-                text = p.get_text(strip=True)
-                if len(text) > 50:
-                    paragraphs.append(text)
-                if len(paragraphs) >= self.max_paragraphs:
-                    break
-
-            links = []
-            for a in soup.find_all('a', href=True):
-                href = a['href']
-                if href.startswith('http') and href != url:
-                    links.append(href)
-                if len(links) >= self.max_links_per_page:
-                    break
-
-            content = {
-                "url": url,
-                "title": title,
-                "paragraphs": paragraphs
-            }
-
-            return content, links
-
-        except Exception as e:
-            print(f"خطا در خزش صفحه {url}: {e}")
-            return None, []
-
-    def crawl_website(self, start_url, max_pages=10):
-        visited = set()
-        to_visit = [start_url]
-        contents = []
-
-        while to_visit and len(visited) < max_pages:
-            current_url = to_visit.pop(0)
-
-            if current_url in visited:
-                continue
-
-            content, links = self.crawl_page(current_url)
-
-            visited.add(current_url)
-
-            if content and content["paragraphs"]:
-                contents.append(content)
-
-            for link in links:
-                if link not in visited and link not in to_visit:
-                    to_visit.append(link)
-
-            time.sleep(1)
-
-        return contents
-
-    def crawl_web(self, query):
-        urls = self.search_duckduckgo(query)
-
-        if not urls:
-            print("هیچ نتیجهای یافت نشد.")
-            return []
-
-        all_results = []
-        for url in urls[:3]:  # Limit to first 3 URLs for efficiency
-            content, links = self.crawl_page(url)
-            if content and content["paragraphs"]:
-                all_results.append(content)
-
-            # Follow links from the main page (recursive crawling)
-            for link in links[:2]:  # Limit to first 2 links
-                sub_content, _ = self.crawl_page(link, depth=1)
-                if sub_content and sub_content["paragraphs"]:
-                    all_results.append(sub_content)
-                time.sleep(1)
-
-            time.sleep(1)
-
-        self.web_results = all_results
-        self._save_web_results()
-
-        # Convert web results to documents for RAG
-        web_docs = []
-        for result in all_results:
-            text = f"[{result['title']}]\n" + "\n".join(result['paragraphs'])
-            web_docs.append({"page_content": text, "metadata": {"source": result['url']}})
-
-        return all_results, web_docs
-
-    def build_retriever(self, documents):
-        if not documents:
-            return None
-
-        # Create BM25 retriever
-        bm25_retriever = BM25Retriever.from_documents(documents)
-        bm25_retriever.k = 3  # Return top 3 results
-
-        return bm25_retriever
-
-    def get_relevant_documents(self, query, documents):
-        retriever = self.build_retriever(documents)
-        if not retriever:
-            return []
-
-        return retriever.get_relevant_documents(query)
-
-    def extract_context_from_documents(self, query):
-        if not self.documents:
-            return None
-
-        relevant_docs = self.get_relevant_documents(query, self.documents)
-
-        if not relevant_docs:
-            return None
-
-        context = "\n\n".join([doc.page_content for doc in relevant_docs])
-        return context
-
-    def extract_context_from_web(self, web_results, web_docs, query):
-        if not web_results or not web_docs:
-            return None, []
-
-        # Try to use the retriever for better results
-        if web_docs:
-            relevant_docs = self.get_relevant_documents(query, web_docs)
-            if relevant_docs:
-                context = "\n\n".join([doc.page_content for doc in relevant_docs])
-                sources = [doc.metadata.get("source", "") for doc in relevant_docs if "source" in doc.metadata]
-                return context, sources
-
-        # Fall back to simple extraction if retriever fails
-        contexts = []
-        sources = []
-
-        for doc in web_results:
-            context_text = "\n".join(doc["paragraphs"])
-            contexts.append(f"[{doc['title']}] {context_text}")
-            sources.append(doc['url'])
-
-        context = "\n\n".join(contexts)
-        return context, sources
-
-
-def get_context(query, crawl_params=None):
-    """
-    سیستم RAG مدولار برای پاسخگویی به سوالات با استفاده از اسناد و جستجوی وب
-
-    پارامترها:
-        query (str): سوال به زبان فارسی
-        crawl_params (dict, optional): پارامترهای خزش وب
-            - max_depth: حداکثر عمق خزش
-            - max_links_per_page: حداکثر تعداد لینکهای استخراج شده از هر صفحه
-            - max_paragraphs: حداکثر تعداد پاراگرافهای استخراج شده از هر صفحه
-            - num_search_results: تعداد نتایج جستجو
-
-    خروجی:
-        dict: نتیجه جستجو شامل متن و منابع
-    """
-    rag = ModularRAG()
-
-    # Configure crawling parameters if provided
-    if crawl_params:
-        if 'max_depth' in crawl_params:
-            rag.max_depth = crawl_params['max_depth']
-        if 'max_links_per_page' in crawl_params:
-            rag.max_links_per_page = crawl_params['max_links_per_page']
-        if 'max_paragraphs' in crawl_params:
-            rag.max_paragraphs = crawl_params['max_paragraphs']
-        if 'num_search_results' in crawl_params:
-            rag.num_search_results = crawl_params['num_search_results']
-
-    # First try to get context from documents
-    doc_context = rag.extract_context_from_documents(query)
-
-    if doc_context:
-        return {
-            "has_context": True,
-            "context": doc_context,
-            "source": "documents",
-            "language": "fa"
-        }
-
-    # Fall back to web search
-    web_results, web_docs = rag.crawl_web(query)
-
-    if web_results:
-        web_context, sources = rag.extract_context_from_web(web_results, web_docs, query)
-        return {
-            "has_context": True,
-            "context": web_context,
-            "source": "web",
-            "sources": sources,
-            "language": "fa"
-        }
-
-    # No context found
-    return {
-        "has_context": False,
-        "context": "متأسفانه اطلاعاتی در مورد سوال شما یافت نشد.",
-        "source": "none",
-        "language": "fa"
-    }
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import os
+import re
+import json
+import ssl
+import argparse
+import requests
+
+from bs4 import BeautifulSoup
+from urllib.parse import quote
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain_core.documents import Document
+import traceback
+
+# Disable SSL warnings and proxy settings
+ssl._create_default_https_context = ssl._create_unverified_context
+requests.packages.urllib3.disable_warnings()
+
+
+def clear_proxy_settings():
+    """Remove proxy environment variables that might cause connection issues."""
+    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
+        if var in os.environ:
+            print(f"Removing proxy env var: {var}")
+            del os.environ[var]
+
+# Run at module load time
+clear_proxy_settings()
+
+
+# Configuration
+DOCUMENT_PATHS = [
+    r'doc1.txt',
+    r'doc2.txt',
+    r'doc3.txt',
+    r'doc4.txt',
+    r'doc5.txt',
+    r'doc6.txt'
+]
+EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+LLM_MODEL = 'gemma3'
+CHUNK_SIZE = 1000
+OVERLAP = 200
+CHROMA_PERSIST_DIR = 'chroma_db'
+
+# Confidence thresholds
+THRESHOLDS = {
+    'direct_answer': 0.7,
+    'rag_confidence': 0.6,
+    'web_search': 0.5
+}
+
+
+def query_llm(prompt, model='gemma3'):
+    """Query the LLM model directly using Ollama API."""
+    try:
+        ollama_endpoint = "http://localhost:11434/api/generate"
+        payload = {
+            "model": model,
+            "prompt": prompt,
+            "stream": False
+        }
+        response = requests.post(ollama_endpoint, json=payload)
+
+        if response.status_code == 200:
+            result = response.json()
+            return result.get('response', '')
+        else:
+            print(f"Ollama API error: {response.status_code}")
+            return f"Error calling Ollama API: {response.status_code}"
+    except Exception as e:
+        print(f"Error querying LLM: {e}")
+        return f"Error: {str(e)}"
+
+
+class BM25Retriever:
+    """BM25 retriever implementation for text similarity search"""
+
+    @classmethod
+    def from_documents(cls, documents):
+        """Create a BM25 retriever from documents"""
+        retriever = cls()
+        retriever.documents = documents
+        retriever.k = 4
+        return retriever
+
+    def get_relevant_documents(self, query):
+        """Get relevant documents using BM25 algorithm"""
+        # Simple BM25-like implementation
+        scores = []
+        query_terms = set(re.findall(r'\b\w+\b', query.lower()))
+
+        for doc in self.documents:
+            doc_terms = set(re.findall(r'\b\w+\b', doc.page_content.lower()))
+            # Calculate term overlap as a simple approximation of BM25
+            overlap = len(query_terms.intersection(doc_terms))
+            scores.append((doc, overlap))
+
+        # Sort by score and return top k
+        sorted_docs = [doc for doc, score in sorted(scores, key=lambda x: x[1], reverse=True)]
+        return sorted_docs[:self.k]
+
+
+class HybridRetriever:
+    """Hybrid retriever combining BM25 and vector search with configurable weights"""
+
+    def __init__(self, vector_retriever, bm25_retriever, vector_weight=0.3):
+        """Initialize with separate retrievers and weights"""
+        self._vector_retriever = vector_retriever
+        self._bm25_retriever = bm25_retriever
+        self._vector_weight = vector_weight
+        self._bm25_weight = 1.0 - vector_weight
+
+    def get_relevant_documents(self, query):
+        """Get relevant documents using weighted combination of retrievers"""
+        try:
+            # Get results from both retrievers
+            vector_docs = self._vector_retriever.get_relevant_documents(query)
+            bm25_docs = self._bm25_retriever.get_relevant_documents(query)
+
+            # Create dictionary to track unique documents and their scores
+            doc_dict = {}
+
+            # Add vector docs with their weights
+            for i, doc in enumerate(vector_docs):
+                # Score based on position (inverse rank)
+                score = (len(vector_docs) - i) * self._vector_weight
+                doc_id = doc.page_content[:50]  # Use first 50 chars as a simple ID
+                if doc_id in doc_dict:
+                    doc_dict[doc_id]["score"] += score
+                else:
+                    doc_dict[doc_id] = {"doc": doc, "score": score}
+
+            # Add BM25 docs with their weights
+            for i, doc in enumerate(bm25_docs):
+                # Score based on position (inverse rank)
+                score = (len(bm25_docs) - i) * self._bm25_weight
+                doc_id = doc.page_content[:50]  # Use first 50 chars as a simple ID
+                if doc_id in doc_dict:
+                    doc_dict[doc_id]["score"] += score
+                else:
+                    doc_dict[doc_id] = {"doc": doc, "score": score}
+
+            # Sort by combined score (highest first)
+            sorted_docs = sorted(doc_dict.values(), key=lambda x: x["score"], reverse=True)
+
+            # Return just the document objects
+            return [item["doc"] for item in sorted_docs]
+        except Exception as e:
+            print(f"Error in hybrid retrieval: {e}")
+            return []
+
+
+class AgenticQASystem:
+    """QA system implementing the specified architecture"""
+
+    def __init__(self):
+        """Initialize the QA system with retrievers"""
+        # Load embeddings
+        self.embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+        # Load documents and retrievers
+        self.documents = self.load_documents()
+        self.retriever = self.initialize_retriever()
+
+    def load_documents(self):
+        """Load documents from configured paths with sliding window chunking"""
+        print("Loading documents...")
+        docs = []
+        for path in DOCUMENT_PATHS:
+            try:
+                with open(path, 'r', encoding='utf-8') as f:
+                    text = re.sub(r'\s+', ' ', f.read()).strip()
+                # Sliding window chunking
+                chunks = [text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE - OVERLAP)]
+                for chunk in chunks:
+                    docs.append(Document(
+                        page_content=chunk,
+                        metadata={"source": os.path.basename(path)}
+                    ))
+            except Exception as e:
+                print(f"Error loading document {path}: {e}")
+        print(f"Loaded {len(docs)} document chunks")
+        return docs
+
+    def initialize_retriever(self):
+        """Initialize the hybrid retriever with BM25 and direct Chroma queries"""
+        if not self.documents:
+            print("No documents loaded, retriever initialization failed")
+            return None
+
+        try:
+            # Create BM25 retriever
+            bm25_retriever = BM25Retriever.from_documents(self.documents)
+            bm25_retriever.k = 4  # Top k results to retrieve
+
+            # Initialize vector store with KNN search
+            import shutil
+            if os.path.exists(CHROMA_PERSIST_DIR):
+                print(f"Removing existing Chroma DB to prevent dimension mismatch")
+                shutil.rmtree(CHROMA_PERSIST_DIR)
+
+            # Create vector store directly from Chroma
+            print("Creating vector store...")
+            vector_store = Chroma.from_documents(
+                documents=self.documents,
+                embedding=self.embeddings,
+                persist_directory=CHROMA_PERSIST_DIR
+            )
+
+            vector_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})
+            print(f"Vector retriever created: {type(vector_retriever)}")
+
+            # Create hybrid retriever - BM25 (70%) and Vector (30%)
+            print("Creating hybrid retriever")
+            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever, vector_weight=0.3)
+            print("Hybrid retriever initialized successfully")
+            return hybrid_retriever
+
+        except Exception as e:
+            print(f"Error initializing retriever: {e}")
+            traceback.print_exc()
+            return None
+
+    def estimate_confidence(self, text, query, context=None):
+        """Estimate confidence of response"""
+        # Start with baseline confidence
+        confidence = 0.5
+
+        # Check for uncertainty markers
+        uncertainty_phrases = [
+            "نمیدانم", "مطمئن نیستم", "ممکن است", "شاید", "احتمالاً",
+            "فکر میکنم", "به نظر میرسد"
+        ]
+
+        if any(phrase in text.lower() for phrase in uncertainty_phrases):
+            confidence -= 0.2
+
+        # Check for question relevance
+        query_words = set(re.findall(r'\b\w+\b', query.lower()))
+        text_words = set(re.findall(r'\b\w+\b', text.lower()))
+
+        # Calculate overlap between query and response
+        if query_words:
+            overlap_ratio = len(query_words.intersection(text_words)) / len(query_words)
+            if overlap_ratio > 0.5:
+                confidence += 0.2
+            elif overlap_ratio < 0.2:
+                confidence -= 0.2
+
+        # If context provided, check context relevance
+        if context:
+            context_words = set(re.findall(r'\b\w+\b', context.lower()))
+            if context_words:
+                context_overlap = len(context_words.intersection(text_words)) / len(context_words)
+                if context_overlap > 0.3:
+                    confidence += 0.2
+                else:
+                    confidence -= 0.1
+
+        # Ensure confidence is within [0,1]
+        return max(0.0, min(1.0, confidence))
+
+    def check_direct_knowledge(self, query):
+        """Check if the LLM can answer directly from its knowledge"""
+        print("Checking LLM's direct knowledge...")
+        prompt = f"""به این سوال با استفاده از دانش خود پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+
+سوال: {query}
+
+پاسخ فارسی:"""
+
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query)
+        print(f"LLM direct knowledge confidence: {confidence:.2f}")
+
+        return response, confidence
+
+    def rag_query(self, query):
+        """Use RAG to retrieve and generate answer"""
+        if not self.retriever:
+            print("Retriever not initialized, skipping RAG")
+            return None, 0.0
+
+        print("Retrieving documents for RAG...")
+        # Retrieve relevant documents
+        docs = self.retriever.get_relevant_documents(query)
+
+        if not docs:
+            print("No relevant documents found")
+            return None, 0.0
+
+        print(f"Retrieved {len(docs)} relevant documents")
+
+        # Prepare context
+        context = "\n\n".join([doc.page_content for doc in docs])
+        sources = [doc.metadata.get("source", "Unknown") for doc in docs]
+
+        # Query LLM with context
+        prompt = f"""با توجه به اطلاعات زیر، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+
+اطلاعات:
+{context}
+
+سوال: {query}
+
+پاسخ فارسی:"""
+
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query, context)
+        print(f"RAG confidence: {confidence:.2f}")
+
+        return {
+            "response": response,
+            "confidence": confidence,
+            "sources": list(set(sources))
+        }, confidence
+
+    def web_search(self, query):
+        """Search the web for an answer"""
+        print("Searching web for answer...")
+        # Search DuckDuckGo
+        search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
+        response = requests.get(search_url, verify=False, timeout=10)
+
+        if response.status_code != 200:
+            print(f"Error searching web: HTTP {response.status_code}")
+            return None, 0.0
+
+        # Parse results
+        soup = BeautifulSoup(response.text, 'html.parser')
+        results = []
+
+        for element in soup.select('.result__url, .result__a')[:4]:
+            href = element.get('href') if 'href' in element.attrs else None
+
+            if href and not href.startswith('/') and (href.startswith('http://') or href.startswith('https://')):
+                results.append(href)
+            elif not href and element.find('a') and 'href' in element.find('a').attrs:
+                href = element.find('a')['href']
+                if href and not href.startswith('/'):
+                    results.append(href)
+
+        if not results:
+            print("No web results found")
+            return None, 0.0
+
+        # Crawl top results
+        web_content = []
+        for url in results[:3]:
+            try:
+                headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
+                page = requests.get(url, headers=headers, timeout=10, verify=False)
+                page.raise_for_status()
+
+                soup = BeautifulSoup(page.text, 'html.parser')
+
+                # Remove non-content elements
+                for tag in ['script', 'style', 'nav', 'footer', 'header']:
+                    for element in soup.find_all(tag):
+                        element.decompose()
+
+                # Get paragraphs
+                paragraphs = [p.get_text(strip=True) for p in soup.find_all('p')
+                              if len(p.get_text(strip=True)) > 20]
+
+                if paragraphs:
+                    web_content.append(f"[Source: {url}] " + " ".join(paragraphs[:5]))
+            except Exception as e:
+                print(f"Error crawling {url}: {e}")
+
+        if not web_content:
+            print("No useful content found from web results")
+            return None, 0.0
+
+        # Query LLM with web content
+        context = "\n\n".join(web_content)
+        prompt = f"""با توجه به اطلاعات زیر که از وب بدست آمده، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+
+اطلاعات:
+{context}
+
+سوال: {query}
+
+پاسخ فارسی:"""
+
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query, context)
+        print(f"Web search confidence: {confidence:.2f}")
+
+        return {
+            "response": response,
+            "confidence": confidence,
+            "sources": results[:3]
+        }, confidence
+
+    def get_answer(self, query):
+        """Main method to get an answer following the specified architecture"""
+        print(f"Processing query: {query}")
+
+        # STEP 1: Try direct LLM knowledge
+        direct_response, direct_confidence = self.check_direct_knowledge(query)
+
+        if direct_confidence >= THRESHOLDS['direct_answer']:
+            print("Using direct LLM knowledge (high confidence)")
+            return f"{direct_response}\n\n[Source: LLM Knowledge, Confidence: {direct_confidence:.2f}]"
+
+        # STEP 2: Try RAG with local documents
+        rag_result, rag_confidence = self.rag_query(query)
+
+        if rag_result and rag_confidence >= THRESHOLDS['rag_confidence']:
+            print("Using RAG response (sufficient confidence)")
+            sources_text = ", ".join(rag_result["sources"][:3])
+            return f"{rag_result['response']}\n\n[Source: Local Documents, Confidence: {rag_confidence:.2f}, Sources: {sources_text}]"
+
+        # STEP 3: Try web search
+        web_result, web_confidence = self.web_search(query)
+
+        if web_result and web_confidence >= THRESHOLDS['web_search']:
+            print("Using web search response (sufficient confidence)")
+            sources_text = ", ".join(web_result["sources"])
+            return f"{web_result['response']}\n\n[Source: Web Search, Confidence: {web_confidence:.2f}, Sources: {sources_text}]"
+
+        # STEP 4: Fall back to direct response with warning
+        print("No high-confidence source found, using direct response with warning")
+        return f"{direct_response}\n\n[Warning: Low confidence ({direct_confidence:.2f}). Please verify information.]"
+
+
+# Simple API functions
+def get_answer(query):
+    """Get an answer for a query"""
+    system = AgenticQASystem()
+    return system.get_answer(query)
+
+
+# Main entry point
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="QA System")
+
+    mode_group = parser.add_mutually_exclusive_group(required=True)
+    mode_group.add_argument("--query", "-q", help="Query to answer")
+    mode_group.add_argument("--interactive", "-i", action="store_true", help="Run in interactive chat mode")
+    mode_group.add_argument("--test", "-t", action="store_true", help="Run tests")
+
+    args = parser.parse_args()
+
+    if args.interactive:
+        # Simple interactive mode without memory
+        qa_system = AgenticQASystem()
+        print("=== QA System ===")
+        print("Type 'exit' or 'quit' to end")
+
+        while True:
+            user_input = input("\nYou: ")
+            if not user_input.strip():
+                continue
+            if user_input.lower() in ['exit', 'quit', 'خروج']:
+                break
+
+            response = qa_system.get_answer(user_input)
+            print(f"\nBot: {response}")
+    elif args.query:
+        qa_system = AgenticQASystem()
+        print(qa_system.get_answer(args.query))
+    elif args.test:
+        print("Running tests...")
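The HybridRetriever above fuses the two result lists by weighted inverse rank: with four results per retriever and vector_weight=0.3, a document ranked first by the vector retriever contributes 4 * 0.3 = 1.2, and the same document ranked second by BM25 adds 3 * 0.7 = 2.1, for a combined score of 3.3. A small self-contained illustration of that arithmetic (hypothetical doc ids, not part of the commit):

# Hypothetical doc ids; reproduces the scoring used in HybridRetriever.get_relevant_documents
vector_weight = 0.3
bm25_weight = 1.0 - vector_weight

vector_ranking = ["doc_a", "doc_b", "doc_c", "doc_d"]  # best first
bm25_ranking = ["doc_b", "doc_a", "doc_e", "doc_c"]

scores = {}
for ranking, weight in ((vector_ranking, vector_weight), (bm25_ranking, bm25_weight)):
    for i, doc_id in enumerate(ranking):
        scores[doc_id] = scores.get(doc_id, 0.0) + (len(ranking) - i) * weight

# doc_b: 3*0.3 + 4*0.7 = 3.7, doc_a: 4*0.3 + 3*0.7 = 3.3, doc_e: 2*0.7 = 1.4, ...
print(sorted(scores.items(), key=lambda kv: kv[1], reverse=True))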
hybrid.py (new file, 204 lines)
@@ -0,0 +1,204 @@
import os
import nltk
from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from typing_extensions import TypedDict
from langgraph.graph import START, END, StateGraph

# Ensure NLTK tokenizer is available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# Initialize model and embeddings
model = ChatOllama(model="gemma3:12b", temperature=0.2)
embeddings = OllamaEmbeddings(model="gemma3:12b")

# Vector store
vector_store = InMemoryVectorStore(embeddings)

# Templates
qa_template = """
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:
"""

# Text splitter
def split_text(documents):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True
    )
    return text_splitter.split_documents(documents)

# PDF handling
def load_pdf(file_path):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    loader = PDFPlumberLoader(file_path)
    documents = loader.load()
    return documents

# Web page handling (using WebBaseLoader)
def load_webpage(url):
    loader = WebBaseLoader(url)
    documents = loader.load()
    return documents

# Hybrid retriever
def build_hybrid_retriever(documents):
    vector_store.clear()
    vector_store.add_documents(documents)
    semantic_retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = 3
    hybrid_retriever = EnsembleRetriever(
        retrievers=[semantic_retriever, bm25_retriever],
        weights=[0.7, 0.3]
    )
    return hybrid_retriever

# DuckDuckGo search implementation
def search_ddg(query, num_results=3):
    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
    search = DuckDuckGoSearchAPIWrapper()
    results = search.results(query, num_results)
    return results

# Answer question with error handling
def answer_question(question, documents):
    try:
        context = "\n\n".join([doc.page_content for doc in documents])
        prompt = ChatPromptTemplate.from_template(qa_template)
        chain = prompt | model
        return chain.invoke({"question": question, "context": context}).content
    except Exception as e:
        return f"Error generating answer: {e}"

# Simple RAG node for web search
class WebSearchState(TypedDict):
    query: str
    results: list
    response: str

def web_search(state):
    results = search_ddg(state["query"])
    return {"results": results}

def generate_search_response(state):
    try:
        context = "\n\n".join([f"{r['title']}: {r['snippet']}" for r in state["results"]])
        prompt = ChatPromptTemplate.from_template(qa_template)
        chain = prompt | model
        response = chain.invoke({"question": state["query"], "context": context})
        return {"response": response.content}
    except Exception as e:
        return {"response": f"Error generating response: {e}"}

# Build search graph
search_graph = StateGraph(WebSearchState)
search_graph.add_node("search", web_search)
search_graph.add_node("generate", generate_search_response)
search_graph.add_edge(START, "search")
search_graph.add_edge("search", "generate")
search_graph.add_edge("generate", END)
search_workflow = search_graph.compile()

# Main command-line interface
if __name__ == "__main__":
    print("Welcome to the Advanced RAG System")
    print("Choose an option:")
    print("1. Analyze PDF")
    print("2. Crawl URL")
    print("3. Search Internet")
    choice = input("Enter your choice (1/2/3): ")

    if choice == "1":
        pdf_path = input("Enter the path to the PDF file: ").strip()
        if not pdf_path:
            print("Please enter a valid file path.")
        else:
            try:
                print("Processing PDF...")
                documents = load_pdf(pdf_path)
                if not documents:
                    print("No documents were loaded from the PDF. The file might be empty or not contain extractable text.")
                else:
                    chunked_documents = split_text(documents)
                    if not chunked_documents:
                        print("No text chunks were created. The PDF might not contain any text.")
                    else:
                        retriever = build_hybrid_retriever(chunked_documents)
                        print(f"Processed {len(chunked_documents)} chunks")
                        question = input("Ask a question about the PDF: ").strip()
                        if not question:
                            print("Please enter a valid question.")
                        else:
                            print("Searching document...")
                            related_documents = retriever.get_relevant_documents(question)
                            if not related_documents:
                                print("No relevant documents found for the question.")
                            else:
                                answer = answer_question(question, related_documents)
                                print("Answer:", answer)
            except Exception as e:
                print(f"Error: {e}")

    elif choice == "2":
        url = input("Enter the URL to analyze: ").strip()
        if not url:
            print("Please enter a valid URL.")
        else:
            try:
                print("Loading webpage...")
                web_documents = load_webpage(url)
                if not web_documents:
                    print("No documents were loaded from the webpage. The page might be empty or not contain extractable text.")
                else:
                    web_chunks = split_text(web_documents)
                    if not web_chunks:
                        print("No text chunks were created. The webpage might not contain any text.")
                    else:
                        web_retriever = build_hybrid_retriever(web_chunks)
                        print(f"Processed {len(web_chunks)} chunks from webpage")
                        question = input("Ask a question about the webpage: ").strip()
                        if not question:
                            print("Please enter a valid question.")
                        else:
                            print("Analyzing content...")
                            web_results = web_retriever.get_relevant_documents(question)
                            if not web_results:
                                print("No relevant documents found for the question.")
                            else:
                                answer = answer_question(question, web_results)
                                print("Answer:", answer)
            except Exception as e:
                print(f"Error loading webpage: {e}")

    elif choice == "3":
        query = input("Enter your search query: ").strip()
        if not query:
            print("Please enter a valid search query.")
        else:
            try:
                print("Searching the web...")
                search_result = search_workflow.invoke({"query": query})
                print("Response:", search_result["response"])
                print("Sources:")
                for result in search_result["results"]:
                    print(f"- {result['title']}: {result['link']}")
            except Exception as e:
                print(f"Error during search: {e}")

    else:
        print("Invalid choice")
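build_hybrid_retriever in hybrid.py delegates the fusion to LangChain's EnsembleRetriever, which the LangChain documentation describes as weighted Reciprocal Rank Fusion. A rough sketch of that scoring rule, with the conventional constant c = 60 and the weights configured above (illustrative only, not the library's actual implementation):

# Illustrative weighted Reciprocal Rank Fusion, not the library's actual code
def weighted_rrf(rankings, weights, c=60):
    # rankings: one ranked list of doc ids per retriever; weights: one weight per list
    scores = {}
    for ranking, weight in zip(rankings, weights):
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (c + rank)
    return sorted(scores, key=scores.get, reverse=True)

# semantic retriever weighted 0.7, BM25 weighted 0.3, as configured in build_hybrid_retriever
print(weighted_rrf([["d1", "d2", "d3"], ["d2", "d3", "d1"]], [0.7, 0.3]))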
memory.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# --- Dependencies ---
# pip install langchain langchain-core langchain-ollama faiss-cpu sentence-transformers

import datetime
import os
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain.memory import ConversationBufferMemory  # Added for intra-session memory
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document  # Needed for manual saving

# --- Config ---
FAISS_INDEX_PATH = "my_chatbot_memory_index"  # Directory to save/load FAISS index

# --- Ollama LLM & Embeddings Setup ---
# Run in terminal: ollama pull gemma3
# Run in terminal: ollama pull nomic-embed-text
OLLAMA_LLM_MODEL = 'gemma3'  # Using Gemma 3 as requested
OLLAMA_EMBED_MODEL = 'nomic-embed-text'  # Recommended embedding model for Ollama

try:
    llm = ChatOllama(model=OLLAMA_LLM_MODEL)
    embeddings = OllamaEmbeddings(model=OLLAMA_EMBED_MODEL)
    print(f"Successfully initialized Ollama: LLM='{OLLAMA_LLM_MODEL}', Embeddings='{OLLAMA_EMBED_MODEL}'")
    # Optional tests removed for brevity
except Exception as e:
    print(f"Error initializing Ollama components: {e}")
    print(f"Ensure Ollama is running & models pulled (e.g., 'ollama pull {OLLAMA_LLM_MODEL}' and 'ollama pull {OLLAMA_EMBED_MODEL}').")
    exit()

# --- Vector Store (Episodic Memory) Setup --- Persisted!
try:
    if os.path.exists(FAISS_INDEX_PATH):
        print(f"Loading existing FAISS index from: {FAISS_INDEX_PATH}")
        vectorstore = FAISS.load_local(
            FAISS_INDEX_PATH,
            embeddings,
            allow_dangerous_deserialization=True  # Required for FAISS loading
        )
        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
        print("FAISS vector store loaded successfully.")
    else:
        print(f"No FAISS index found at {FAISS_INDEX_PATH}. Initializing new store.")
        # FAISS needs at least one text to initialize.
        vectorstore = FAISS.from_texts(
            ["Initial conversation context placeholder - Bot created"],
            embeddings
        )
        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
        # Save the initial empty index
        vectorstore.save_local(FAISS_INDEX_PATH)
        print("New FAISS vector store initialized and saved.")

except Exception as e:
    print(f"Error initializing/loading FAISS: {e}")
    print("Check permissions or delete the index directory if corrupted.")
    exit()

# --- Conversation Buffer (Short-Term) Memory Setup ---
# memory_key must match the input variable in the prompt
# return_messages=True formats history as suitable list of BaseMessages
buffer_memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)
# <<< ADDED: Clear buffer at the start of each script run >>>
buffer_memory.clear()

# --- Define the Prompt Template ---
# Now includes chat_history for the buffer memory
template = """You are a helpful chatbot assistant with episodic memory (from past sessions) and conversational awareness (from the current session).
Use the following relevant pieces of information:
1. Episodic Memory (Knowledge from *previous* chat sessions):
{semantic_context}

2. Chat History (What we've discussed in the *current* session):
{chat_history}

Combine this information with the current user input to generate a coherent and contextually relevant answer.
If recalling information from Episodic Memory, you can mention it stems from a past conversation if appropriate.
If no relevant context or history is found, just respond naturally to the current input.

Current Input:
User: {input}
Assistant:"""

prompt = PromptTemplate(
    input_variables=["semantic_context", "chat_history", "input"],
    template=template
)

# --- Helper Function for Formatting Retrieved Docs (Episodic Memory) ---
# Formats the retrieved documents (past interactions) for the prompt
def format_retrieved_docs(docs):
    # Simplified formatting: Extract core content only and label explicitly
    formatted = []
    for doc in docs:
        content = doc.page_content
        # Basic check to remove placeholder
        if content not in ["Initial conversation context placeholder - Bot created"]:
            # Attempt to strip "Role (timestamp): " prefix if present
            if "):" in content:
                content = content.split("):", 1)[-1].strip()
            if content:  # Ensure content is not empty after stripping
                formatted.append(f"Recalled from a past session: {content}")
    # Use a double newline to separate recalled memories clearly
    return "\n\n".join(formatted) if formatted else "No relevant memories found from past sessions."


# --- Chain Definition using LCEL ---

# Function to load episodic memory (FAISS context)
def load_episodic_memory(input_dict):
    query = input_dict.get("input", "")
    docs = retriever.invoke(query)
    return format_retrieved_docs(docs)

# Function to save episodic memory (and persist FAISS index)
def save_episodic_memory_step(inputs_outputs):
    user_input = inputs_outputs.get("input", "")
    llm_output = inputs_outputs.get("output", "")

    if user_input and llm_output:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        docs_to_add = [
            Document(page_content=f"User ({timestamp}): {user_input}"),
            Document(page_content=f"Assistant ({timestamp}): {llm_output}")
        ]
        vectorstore.add_documents(docs_to_add)
        vectorstore.save_local(FAISS_INDEX_PATH)  # Persist index after adding
        # print(f"DEBUG: Saved to FAISS index: {FAISS_INDEX_PATH}")
    return inputs_outputs  # Pass the dict through for potential further steps


# Define the core chain logic
chain_core = (
    RunnablePassthrough.assign(
        semantic_context=RunnableLambda(load_episodic_memory),
        chat_history=RunnableLambda(lambda x: buffer_memory.load_memory_variables(x)['chat_history'])
    )
    | prompt
    | llm
    | StrOutputParser()
)

# Wrap the core logic to handle memory updates
def run_chain(input_dict):
    user_input = input_dict['input']

    # Invoke the core chain to get the response
    llm_response = chain_core.invoke({"input": user_input})

    # Prepare data for saving
    save_data = {"input": user_input, "output": llm_response}

    # Save to episodic memory (FAISS)
    save_episodic_memory_step(save_data)

    # Save to buffer memory
    buffer_memory.save_context({"input": user_input}, {"output": llm_response})

    return llm_response


# --- Chat Loop ---
print(f"\nChatbot Ready! Using Ollama ('{OLLAMA_LLM_MODEL}' chat, '{OLLAMA_EMBED_MODEL}' embed)")
print(f"Episodic memory stored in: {FAISS_INDEX_PATH}")
print("Type 'quit', 'exit', or 'bye' to end the conversation.")

while True:
    user_text = input("You: ")
    if user_text.lower() in ["quit", "exit", "bye"]:
        # Optionally clear buffer memory on exit if desired
        buffer_memory.clear()
        print("Chatbot: Goodbye!")
        break
    if not user_text:
        continue

    try:
        # Use the wrapper function to handle the chain invocation and memory updates
        response = run_chain({"input": user_text})
        print(f"Chatbot: {response}")

        # Optional debug: View buffer memory
        # print("DEBUG: Buffer Memory:", buffer_memory.load_memory_variables({}))
        # Optional debug: Check vector store size
        # print(f"DEBUG: Vector store size: {vectorstore.index.ntotal}")

    except Exception as e:
        print(f"\nAn error occurred during the chat chain: {e}")
        # Add more detailed error logging if needed
        import traceback
        print(traceback.format_exc())

# --- End of Script ---
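Because memory.py saves the FAISS index to disk after every turn, a later session or a separate script can reload the same index and query past conversations directly. A minimal sketch reusing the configuration above:

# inspect_memory.py -- minimal sketch, reuses the index path and embedding model from memory.py
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="nomic-embed-text")
store = FAISS.load_local(
    "my_chatbot_memory_index",
    embeddings,
    allow_dangerous_deserialization=True,
)
for doc in store.similarity_search("what did we discuss last session?", k=3):
    print(doc.page_content)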