Upload files to "/"

MasihMoafi 2025-05-02 11:12:23 +00:00
parent dfdeab2855
commit 4c9da22a39
5 changed files with 1103 additions and 313 deletions

Multimodal.py (new file, 267 lines)

@@ -0,0 +1,267 @@
import os
import subprocess
import tempfile

# Clear proxy settings so local Ollama calls are not routed through a proxy
def clear_proxy_settings():
    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
        if var in os.environ:
            del os.environ[var]

clear_proxy_settings()
from datetime import datetime
import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
from langchain_text_splitters import RecursiveCharacterTextSplitter
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy
from search_utils import duckduckgo_search, rank_results
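# The Persian template below instructs the model: "You are an assistant that
# uses textual and visual data to answer the user's questions in fluent Persian."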
template = """
تو یک دستیار هستی که از داده‌های متنی و تصویری استفاده می‌کنی تا به سوالات کاربر به زبان فارسی سلیس پاسخ بدهی.
Question: {question}
Context: {context}
Answer:
"""
pdfs_directory = 'multi-modal-rag/pdfs/'
figures_directory = 'multi-modal-rag/figures/'
images_directory = 'multi-modal-rag/images/'
videos_directory = 'multi-modal-rag/videos/'
audio_directory = 'multi-modal-rag/audio/'
frames_directory = 'multi-modal-rag/frames/'
# Create directories if they don't exist
os.makedirs(pdfs_directory, exist_ok=True)
os.makedirs(figures_directory, exist_ok=True)
os.makedirs(images_directory, exist_ok=True)
os.makedirs(videos_directory, exist_ok=True)
os.makedirs(audio_directory, exist_ok=True)
os.makedirs(frames_directory, exist_ok=True)
embeddings = OllamaEmbeddings(model="llama3.2")
vector_store = InMemoryVectorStore(embeddings)
model = OllamaLLM(model="gemma3")
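# Note: llama3.2 is used here only to produce embeddings through Ollama's
# embeddings endpoint; gemma3 handles image descriptions and final answers.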
def upload_pdf(file):
    with open(pdfs_directory + file.name, "wb") as f:
        f.write(file.getbuffer())

def upload_image(file):
    with open(images_directory + file.name, "wb") as f:
        f.write(file.getbuffer())
    return images_directory + file.name

def upload_video(file):
    file_path = videos_directory + file.name
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())
    return file_path

def upload_audio(file):
    file_path = audio_directory + file.name
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())
    return file_path
def load_pdf(file_path):
    elements = partition_pdf(
        file_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=figures_directory
    )
    text_elements = [element.text for element in elements if element.category not in ["Image", "Table"]]
    # Caption every extracted figure. Note this scans the whole figures
    # directory, so figures left over from previously processed PDFs are
    # re-described as well.
    for file in os.listdir(figures_directory):
        extracted_text = extract_text(figures_directory + file)
        text_elements.append(extracted_text)
    return "\n\n".join(text_elements)
def extract_frames(video_path, num_frames=5):
    """Extract frames from a video file and save them to the frames directory."""
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    base_name = os.path.basename(video_path).split('.')[0]
    frame_paths = []
    # Extract frames using ffmpeg. Note: the -ss offsets are i/num_frames
    # seconds, so all frames are sampled from the first second of the video.
    for i in range(num_frames):
        frame_path = f"{frames_directory}{base_name}_{timestamp}_{i}.jpg"
        cmd = [
            'ffmpeg', '-i', video_path,
            '-ss', str(i * (1 / num_frames)), '-vframes', '1',
            '-q:v', '2', frame_path, '-y'
        ]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            frame_paths.append(frame_path)
        except subprocess.CalledProcessError:
            st.warning(f"Failed to extract frame {i} from video")
    return frame_paths
def process_audio(audio_path):
    """Describe an audio file. The text LLM cannot actually hear the audio;
    the prompt only contains the filename, so the result is a placeholder."""
    audio_description = model.invoke(
        f"Describe what you hear in this audio file: {os.path.basename(audio_path)}"
    )
    return f"Audio file: {os.path.basename(audio_path)}. Description: {audio_description}"

def extract_text(file_path):
    # Bind the image so the multimodal model can describe it
    model_with_image_context = model.bind(images=[file_path])
    return model_with_image_context.invoke("Tell me what you see in this picture.")
def split_text(text):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True
    )
    return text_splitter.split_text(text)

def index_docs(texts):
    vector_store.add_texts(texts)

def retrieve_docs(query):
    return vector_store.similarity_search(query)

def answer_question(question, documents):
    local_context = "\n\n".join([doc.page_content for doc in documents])
    prompt = ChatPromptTemplate.from_template(template)
    chain = prompt | model
    return chain.invoke({"question": question, "context": local_context})
# Sidebar for upload options
st.sidebar.title("Upload Documents")
upload_option = st.sidebar.radio("Choose upload type:", ["PDF", "Image", "Video", "Audio", "Search"])

if upload_option == "Search":
    st.title("Web Search with BM25 Ranking")
    search_query = st.text_input("Enter your search query:")
    if search_query:
        with st.spinner("Searching and ranking results..."):
            # Get search results
            search_results = duckduckgo_search(search_query, max_results=10)
            if search_results:
                # Rank results using BM25
                ranked_results = rank_results(search_query, search_results)
                # Display results
                st.subheader("Ranked Search Results")
                for i, result in enumerate(ranked_results):
                    with st.expander(f"{i+1}. {result.title}"):
                        st.write(f"**Snippet:** {result.snippet}")
                        st.write(f"**URL:** {result.url}")
                # Option to ask about search results
                st.subheader("Ask about these results")
                question = st.text_input("Enter your question about the search results:")
                if question:
                    # Prepare context from top results
                    context = "\n\n".join([f"Title: {r.title}\nSnippet: {r.snippet}" for r in ranked_results[:3]])
                    # Use the model to answer
                    prompt = ChatPromptTemplate.from_template(template)
                    chain = prompt | model
                    with st.spinner("Generating answer..."):
                        response = chain.invoke({"question": question, "context": context})
                    st.markdown("### Answer")
                    # OllamaLLM returns a plain string, not a message object
                    st.write(response)
            else:
                st.warning("No search results found")
elif upload_option == "PDF":
    uploaded_file = st.file_uploader(
        "Upload PDF",
        type="pdf",
        accept_multiple_files=False
    )
    if uploaded_file:
        upload_pdf(uploaded_file)
        with st.spinner("Processing PDF..."):
            text = load_pdf(pdfs_directory + uploaded_file.name)
            chunked_texts = split_text(text)
            index_docs(chunked_texts)
        st.success("PDF processed successfully!")
elif upload_option == "Image":
    uploaded_image = st.file_uploader(
        "Upload Image",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=False
    )
    if uploaded_image:
        image_path = upload_image(uploaded_image)
        st.image(image_path, caption="Uploaded Image", use_column_width=True)
        with st.spinner("Processing image..."):
            image_description = extract_text(image_path)
            index_docs([image_description])
        st.success("Image processed and added to knowledge base")
elif upload_option == "Video":
    uploaded_video = st.file_uploader(
        "Upload Video",
        type=["mp4", "avi", "mov", "mkv"],
        accept_multiple_files=False
    )
    if uploaded_video:
        video_path = upload_video(uploaded_video)
        st.video(video_path)
        with st.spinner("Processing video frames..."):
            frame_paths = extract_frames(video_path)
            video_descriptions = []
            for frame_path in frame_paths:
                st.image(frame_path, caption="Frame from video", width=200)
                frame_description = extract_text(frame_path)
                video_descriptions.append(frame_description)
            # Add a combined description
            combined_description = f"Video file: {uploaded_video.name}. Content description: " + " ".join(video_descriptions)
            index_docs([combined_description])
        st.success("Video processed and added to knowledge base")
else:  # Audio option
    uploaded_audio = st.file_uploader(
        "Upload Audio",
        type=["mp3", "wav", "ogg"],
        accept_multiple_files=False
    )
    if uploaded_audio:
        audio_path = upload_audio(uploaded_audio)
        st.audio(audio_path)
        with st.spinner("Processing audio..."):
            # For audio, we'll use the model directly without visual context
            audio_description = process_audio(audio_path)
            index_docs([audio_description])
        st.success("Audio processed and added to knowledge base")

# Chat interface
question = st.chat_input()
if question:
    st.chat_message("user").write(question)
    related_documents = retrieve_docs(question)
    answer = answer_question(question, related_documents)
    st.chat_message("assistant").write(answer)
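
Note: Multimodal.py imports duckduckgo_search and rank_results from a search_utils module that is not part of this commit. The sketch below is one minimal way that module could look, inferred only from how the code above uses the results' .title, .snippet, and .url attributes and the max_results parameter; the SearchResult dataclass and the duckduckgo-search and rank-bm25 dependencies are assumptions, not the committed implementation.

# search_utils.py -- hypothetical sketch, not part of this commit
from dataclasses import dataclass
from duckduckgo_search import DDGS   # assumed dependency: pip install duckduckgo-search
from rank_bm25 import BM25Okapi      # assumed dependency: pip install rank-bm25

@dataclass
class SearchResult:
    title: str
    snippet: str
    url: str

def duckduckgo_search(query, max_results=10):
    """Return DuckDuckGo results as SearchResult objects."""
    with DDGS() as ddgs:
        hits = ddgs.text(query, max_results=max_results)
        return [SearchResult(h["title"], h["body"], h["href"]) for h in hits]

def rank_results(query, results):
    """Re-rank results by BM25 similarity between the query and title+snippet."""
    corpus = [f"{r.title} {r.snippet}".lower().split() for r in results]
    scores = BM25Okapi(corpus).get_scores(query.lower().split())
    return [r for _, r in sorted(zip(scores, results), key=lambda p: p[0], reverse=True)]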


@ -1,116 +1,332 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 import os
-import pickle
 import re
-import json
 import nltk
 import ssl
+import argparse
 import requests
-import time
 from bs4 import BeautifulSoup
 from urllib.parse import quote
-from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-from langchain_community.retrievers import BM25Retriever
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.vectorstores import Chroma
+from langchain_core.documents import Document
+import traceback
 try:
     nltk.data.find('tokenizers/punkt')
 except LookupError:
     nltk.download('punkt')
 # Disable SSL warnings and proxy settings
 ssl._create_default_https_context = ssl._create_unverified_context
 requests.packages.urllib3.disable_warnings()
-class ModularRAG:
-    def __init__(self):
-        self.storage_path = "./rag_data"
+def clear_proxy_settings():
+    """Remove proxy environment variables that might cause connection issues."""
+    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
+        if var in os.environ:
+            print(f"Removing proxy env var: {var}")
+            del os.environ[var]
-        if not os.path.exists(self.storage_path):
-            os.makedirs(self.storage_path)
-            os.makedirs(os.path.join(self.storage_path, "documents"))
-            os.makedirs(os.path.join(self.storage_path, "web_results"))
+# Run at module load time
+clear_proxy_settings()
-        self.documents = []
-        self.web_results = []
+# Configuration
+DOCUMENT_PATHS = [
+    r'doc1.txt',
+    r'doc2.txt',
+    r'doc3.txt',
+    r'doc4.txt',
+    r'doc5.txt',
+    r'doc6.txt'
+]
+EMBEDDING_MODEL = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
+LLM_MODEL = 'gemma3'
+CHUNK_SIZE = 1000
+OVERLAP = 200
+CHROMA_PERSIST_DIR = 'chroma_db'
-        # Web crawler settings
-        self.headers = {
-            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+# Confidence thresholds
+THRESHOLDS = {
+    'direct_answer': 0.7,
+    'rag_confidence': 0.6,
+    'web_search': 0.5
 }
+def query_llm(prompt, model='gemma3'):
+    """Query the LLM model directly using Ollama API."""
+    try:
+        ollama_endpoint = "http://localhost:11434/api/generate"
+        payload = {
+            "model": model,
+            "prompt": prompt,
+            "stream": False
+        }
-        self.num_search_results = 10
-        self.max_depth = 2
-        self.max_links_per_page = 5
-        self.max_paragraphs = 5
+        response = requests.post(ollama_endpoint, json=payload)
-        self._load_saved_data()
-    def _load_saved_data(self):
-        doc_path = os.path.join(self.storage_path, "documents", "docs.pkl")
-        web_path = os.path.join(self.storage_path, "web_results", "web.json")
-        if os.path.exists(doc_path):
-            try:
-                with open(doc_path, 'rb') as f:
-                    self.documents = pickle.load(f)
+        if response.status_code == 200:
+            result = response.json()
+            return result.get('response', '')
+        else:
+            print(f"Ollama API error: {response.status_code}")
+            return f"Error calling Ollama API: {response.status_code}"
     except Exception as e:
-                print(f"خطا در بارگیری اسناد: {e}")
+        print(f"Error querying LLM: {e}")
+        return f"Error: {str(e)}"
-        if os.path.exists(web_path):
+class BM25Retriever:
+    """BM25 retriever implementation for text similarity search"""
+    @classmethod
+    def from_documents(cls, documents):
+        """Create a BM25 retriever from documents"""
+        retriever = cls()
+        retriever.documents = documents
+        retriever.k = 4
+        return retriever
+    def get_relevant_documents(self, query):
+        """Get relevant documents using BM25 algorithm"""
+        # Simple BM25-like implementation
+        scores = []
+        query_terms = set(re.findall(r'\b\w+\b', query.lower()))
+        for doc in self.documents:
+            doc_terms = set(re.findall(r'\b\w+\b', doc.page_content.lower()))
+            # Calculate term overlap as a simple approximation of BM25
+            overlap = len(query_terms.intersection(doc_terms))
+            scores.append((doc, overlap))
+        # Sort by score and return top k
+        sorted_docs = [doc for doc, score in sorted(scores, key=lambda x: x[1], reverse=True)]
+        return sorted_docs[:self.k]
+class HybridRetriever:
+    """Hybrid retriever combining BM25 and vector search with configurable weights"""
+    def __init__(self, vector_retriever, bm25_retriever, vector_weight=0.3):
+        """Initialize with separate retrievers and weights"""
+        self._vector_retriever = vector_retriever
+        self._bm25_retriever = bm25_retriever
+        self._vector_weight = vector_weight
+        self._bm25_weight = 1.0 - vector_weight
+    def get_relevant_documents(self, query):
+        """Get relevant documents using weighted combination of retrievers"""
         try:
-                with open(web_path, 'r', encoding='utf-8') as f:
-                    self.web_results = json.load(f)
+            # Get results from both retrievers
+            vector_docs = self._vector_retriever.get_relevant_documents(query)
+            bm25_docs = self._bm25_retriever.get_relevant_documents(query)
+            # Create dictionary to track unique documents and their scores
+            doc_dict = {}
+            # Add vector docs with their weights
+            for i, doc in enumerate(vector_docs):
+                # Score based on position (inverse rank)
+                score = (len(vector_docs) - i) * self._vector_weight
+                doc_id = doc.page_content[:50]  # Use first 50 chars as a simple ID
+                if doc_id in doc_dict:
+                    doc_dict[doc_id]["score"] += score
+                else:
+                    doc_dict[doc_id] = {"doc": doc, "score": score}
+            # Add BM25 docs with their weights
+            for i, doc in enumerate(bm25_docs):
+                # Score based on position (inverse rank)
+                score = (len(bm25_docs) - i) * self._bm25_weight
+                doc_id = doc.page_content[:50]  # Use first 50 chars as a simple ID
+                if doc_id in doc_dict:
+                    doc_dict[doc_id]["score"] += score
+                else:
+                    doc_dict[doc_id] = {"doc": doc, "score": score}
+            # Sort by combined score (highest first)
+            sorted_docs = sorted(doc_dict.values(), key=lambda x: x["score"], reverse=True)
+            # Return just the document objects
+            return [item["doc"] for item in sorted_docs]
         except Exception as e:
-                print(f"خطا در بارگیری نتایج وب: {e}")
-    def _save_documents(self):
-        doc_path = os.path.join(self.storage_path, "documents", "docs.pkl")
-        try:
-            with open(doc_path, 'wb') as f:
-                pickle.dump(self.documents, f)
-        except Exception as e:
-            print(f"خطا در ذخیره‌سازی اسناد: {e}")
-    def _save_web_results(self):
-        web_path = os.path.join(self.storage_path, "web_results", "web.json")
-        try:
-            with open(web_path, 'w', encoding='utf-8') as f:
-                json.dump(self.web_results, f, ensure_ascii=False, indent=2)
-        except Exception as e:
-            print(f"خطا در ذخیره‌سازی نتایج وب: {e}")
-    def load_pdf(self, file_path):
-        if not os.path.exists(file_path):
-            raise FileNotFoundError(f"فایل یافت نشد: {file_path}")
-        try:
-            loader = PDFPlumberLoader(file_path)
-            documents = loader.load()
-            if documents:
-                text_splitter = RecursiveCharacterTextSplitter(
-                    chunk_size=1000,
-                    chunk_overlap=200,
-                    add_start_index=True
-                )
-                chunked_docs = text_splitter.split_documents(documents)
-                self.documents.extend(chunked_docs)
-                self._save_documents()
-                return len(chunked_docs)
-            return 0
-        except Exception as e:
-            raise Exception(f"خطا در بارگیری PDF: {e}")
-    def search_duckduckgo(self, query, num_results=None):
-        if num_results is None:
-            num_results = self.num_search_results
-        try:
-            search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
-            response = requests.get(search_url, headers=self.headers, timeout=10)
-            if response.status_code != 200:
-                print(f"خطا در جستجوی وب: HTTP {response.status_code}")
+            print(f"Error in hybrid retrieval: {e}")
             return []
+class AgenticQASystem:
+    """QA system implementing the specified architecture"""
+    def __init__(self):
+        """Initialize the QA system with retrievers"""
+        # Load embeddings
+        self.embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+        # Load documents and retrievers
+        self.documents = self.load_documents()
+        self.retriever = self.initialize_retriever()
+    def load_documents(self):
+        """Load documents from configured paths with sliding window chunking"""
+        print("Loading documents...")
+        docs = []
+        for path in DOCUMENT_PATHS:
+            try:
+                with open(path, 'r', encoding='utf-8') as f:
+                    text = re.sub(r'\s+', ' ', f.read()).strip()
+                # Sliding window chunking
+                chunks = [text[i:i+CHUNK_SIZE] for i in range(0, len(text), CHUNK_SIZE - OVERLAP)]
+                for chunk in chunks:
+                    docs.append(Document(
+                        page_content=chunk,
+                        metadata={"source": os.path.basename(path)}
+                    ))
+            except Exception as e:
+                print(f"Error loading document {path}: {e}")
+        print(f"Loaded {len(docs)} document chunks")
+        return docs
+    def initialize_retriever(self):
+        """Initialize the hybrid retriever with BM25 and direct Chroma queries"""
+        if not self.documents:
+            print("No documents loaded, retriever initialization failed")
+            return None
+        try:
+            # Create BM25 retriever
+            bm25_retriever = BM25Retriever.from_documents(self.documents)
+            bm25_retriever.k = 4  # Top k results to retrieve
+            # Initialize vector store with KNN search
+            import shutil
+            if os.path.exists(CHROMA_PERSIST_DIR):
+                print(f"Removing existing Chroma DB to prevent dimension mismatch")
+                shutil.rmtree(CHROMA_PERSIST_DIR)
+            # Create vector store directly from Chroma
+            print("Creating vector store...")
+            vector_store = Chroma.from_documents(
+                documents=self.documents,
+                embedding=self.embeddings,
+                persist_directory=CHROMA_PERSIST_DIR
+            )
+            vector_retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 4})
+            print(f"Vector retriever created: {type(vector_retriever)}")
+            # Create hybrid retriever - BM25 (70%) and Vector (30%)
+            print("Creating hybrid retriever")
+            hybrid_retriever = HybridRetriever(vector_retriever, bm25_retriever, vector_weight=0.3)
+            print("Hybrid retriever initialized successfully")
+            return hybrid_retriever
+        except Exception as e:
+            print(f"Error initializing retriever: {e}")
+            traceback.print_exc()
+            return None
+    def estimate_confidence(self, text, query, context=None):
+        """Estimate confidence of response"""
+        # Start with baseline confidence
+        confidence = 0.5
+        # Check for uncertainty markers
+        uncertainty_phrases = [
+            "نمی‌دانم", "مطمئن نیستم", "ممکن است", "شاید", "احتمالاً",
+            "فکر می‌کنم", "به نظر می‌رسد"
+        ]
+        if any(phrase in text.lower() for phrase in uncertainty_phrases):
+            confidence -= 0.2
+        # Check for question relevance
+        query_words = set(re.findall(r'\b\w+\b', query.lower()))
+        text_words = set(re.findall(r'\b\w+\b', text.lower()))
+        # Calculate overlap between query and response
+        if query_words:
+            overlap_ratio = len(query_words.intersection(text_words)) / len(query_words)
+            if overlap_ratio > 0.5:
+                confidence += 0.2
+            elif overlap_ratio < 0.2:
+                confidence -= 0.2
+        # If context provided, check context relevance
+        if context:
+            context_words = set(re.findall(r'\b\w+\b', context.lower()))
+            if context_words:
+                context_overlap = len(context_words.intersection(text_words)) / len(context_words)
+                if context_overlap > 0.3:
+                    confidence += 0.2
+                else:
+                    confidence -= 0.1
+        # Ensure confidence is within [0,1]
+        return max(0.0, min(1.0, confidence))
+    def check_direct_knowledge(self, query):
+        """Check if the LLM can answer directly from its knowledge"""
+        print("Checking LLM's direct knowledge...")
+        prompt = f"""به این سوال با استفاده از دانش خود پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+سوال: {query}
+پاسخ فارسی:"""
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query)
+        print(f"LLM direct knowledge confidence: {confidence:.2f}")
+        return response, confidence
+    def rag_query(self, query):
+        """Use RAG to retrieve and generate answer"""
+        if not self.retriever:
+            print("Retriever not initialized, skipping RAG")
+            return None, 0.0
+        print("Retrieving documents for RAG...")
+        # Retrieve relevant documents
+        docs = self.retriever.get_relevant_documents(query)
+        if not docs:
+            print("No relevant documents found")
+            return None, 0.0
+        print(f"Retrieved {len(docs)} relevant documents")
+        # Prepare context
+        context = "\n\n".join([doc.page_content for doc in docs])
+        sources = [doc.metadata.get("source", "Unknown") for doc in docs]
+        # Query LLM with context
+        prompt = f"""با توجه به اطلاعات زیر، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
+اطلاعات:
+{context}
+سوال: {query}
+پاسخ فارسی:"""
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query, context)
+        print(f"RAG confidence: {confidence:.2f}")
+        return {
+            "response": response,
+            "confidence": confidence,
+            "sources": list(set(sources))
+        }, confidence
+    def web_search(self, query):
+        """Search the web for an answer"""
+        print("Searching web for answer...")
+        # Search DuckDuckGo
+        search_url = f"https://html.duckduckgo.com/html/?q={quote(query)}"
+        response = requests.get(search_url, verify=False, timeout=10)
+        if response.status_code != 200:
+            print(f"Error searching web: HTTP {response.status_code}")
+            return None, 0.0
+        # Parse results
         soup = BeautifulSoup(response.text, 'html.parser')
         results = []
-            for element in soup.select('.result__url, .result__a'):
+        for element in soup.select('.result__url, .result__a')[:4]:
             href = element.get('href') if 'href' in element.attrs else None
+            if href and not href.startswith('/') and (href.startswith('http://') or href.startswith('https://')):
@@ -120,221 +336,125 @@ class ModularRAG:
-                if href and not href.startswith('/'):
                 results.append(href)
-            unique_results = list(set(results))
-            return unique_results[:num_results]
-        except Exception as e:
-            print(f"خطا در جستجوی DuckDuckGo: {e}")
-            return []
-    def crawl_page(self, url, depth=0):
-        if depth > self.max_depth:
-            return None, []
+        if not results:
+            print("No web results found")
+            return None, 0.0
+        # Crawl top results
+        web_content = []
+        for url in results[:3]:
             try:
-                response = requests.get(url, headers=self.headers, timeout=10)
-                response.raise_for_status()
+                headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
+                page = requests.get(url, headers=headers, timeout=10, verify=False)
+                page.raise_for_status()
-                soup = BeautifulSoup(response.text, 'html.parser')
+                soup = BeautifulSoup(page.text, 'html.parser')
-                title = soup.title.string if soup.title else "بدون عنوان"
                 # Remove non-content elements
                 for tag in ['script', 'style', 'nav', 'footer', 'header']:
                     for element in soup.find_all(tag):
                         element.decompose()
-                paragraphs = []
-                for p in soup.find_all('p'):
-                    text = p.get_text(strip=True)
-                    if len(text) > 50:
-                        paragraphs.append(text)
-                        if len(paragraphs) >= self.max_paragraphs:
-                            break
-                links = []
-                for a in soup.find_all('a', href=True):
-                    href = a['href']
-                    if href.startswith('http') and href != url:
-                        links.append(href)
-                        if len(links) >= self.max_links_per_page:
-                            break
-                content = {
-                    "url": url,
-                    "title": title,
-                    "paragraphs": paragraphs
-                }
-                return content, links
+                # Get paragraphs
+                paragraphs = [p.get_text(strip=True) for p in soup.find_all('p')
+                              if len(p.get_text(strip=True)) > 20]
+                if paragraphs:
+                    web_content.append(f"[Source: {url}] " + " ".join(paragraphs[:5]))
             except Exception as e:
-                print(f"خطا در خزش صفحه {url}: {e}")
-                return None, []
+                print(f"Error crawling {url}: {e}")
-    def crawl_website(self, start_url, max_pages=10):
-        visited = set()
-        to_visit = [start_url]
-        contents = []
+        if not web_content:
+            print("No useful content found from web results")
+            return None, 0.0
-        while to_visit and len(visited) < max_pages:
-            current_url = to_visit.pop(0)
+        # Query LLM with web content
+        context = "\n\n".join(web_content)
+        prompt = f"""با توجه به اطلاعات زیر که از وب بدست آمده، به سوال پاسخ دهید. فقط به زبان فارسی پاسخ دهید.
-            if current_url in visited:
+اطلاعات:
+{context}
+سوال: {query}
+پاسخ فارسی:"""
+        response = query_llm(prompt, model=LLM_MODEL)
+        confidence = self.estimate_confidence(response, query, context)
+        print(f"Web search confidence: {confidence:.2f}")
+        return {
+            "response": response,
+            "confidence": confidence,
+            "sources": results[:3]
+        }, confidence
+    def get_answer(self, query):
+        """Main method to get an answer following the specified architecture"""
+        print(f"Processing query: {query}")
+        # STEP 1: Try direct LLM knowledge
+        direct_response, direct_confidence = self.check_direct_knowledge(query)
+        if direct_confidence >= THRESHOLDS['direct_answer']:
+            print("Using direct LLM knowledge (high confidence)")
+            return f"{direct_response}\n\n[Source: LLM Knowledge, Confidence: {direct_confidence:.2f}]"
+        # STEP 2: Try RAG with local documents
+        rag_result, rag_confidence = self.rag_query(query)
+        if rag_result and rag_confidence >= THRESHOLDS['rag_confidence']:
+            print("Using RAG response (sufficient confidence)")
+            sources_text = ", ".join(rag_result["sources"][:3])
+            return f"{rag_result['response']}\n\n[Source: Local Documents, Confidence: {rag_confidence:.2f}, Sources: {sources_text}]"
+        # STEP 3: Try web search
+        web_result, web_confidence = self.web_search(query)
+        if web_result and web_confidence >= THRESHOLDS['web_search']:
+            print("Using web search response (sufficient confidence)")
+            sources_text = ", ".join(web_result["sources"])
+            return f"{web_result['response']}\n\n[Source: Web Search, Confidence: {web_confidence:.2f}, Sources: {sources_text}]"
+        # STEP 4: Fall back to direct response with warning
+        print("No high-confidence source found, using direct response with warning")
+        return f"{direct_response}\n\n[Warning: Low confidence ({direct_confidence:.2f}). Please verify information.]"
+# Simple API functions
+def get_answer(query):
+    """Get an answer for a query"""
+    system = AgenticQASystem()
+    return system.get_answer(query)
+# Main entry point
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="QA System")
+    mode_group = parser.add_mutually_exclusive_group(required=True)
+    mode_group.add_argument("--query", "-q", help="Query to answer")
+    mode_group.add_argument("--interactive", "-i", action="store_true", help="Run in interactive chat mode")
+    mode_group.add_argument("--test", "-t", action="store_true", help="Run tests")
+    args = parser.parse_args()
+    if args.interactive:
+        # Simple interactive mode without memory
+        qa_system = AgenticQASystem()
+        print("=== QA System ===")
+        print("Type 'exit' or 'quit' to end")
+        while True:
+            user_input = input("\nYou: ")
+            if not user_input.strip():
                 continue
-            content, links = self.crawl_page(current_url)
+            if user_input.lower() in ['exit', 'quit', 'خروج']:
                 break
-            visited.add(current_url)
-            if content and content["paragraphs"]:
-                contents.append(content)
-            for link in links:
-                if link not in visited and link not in to_visit:
-                    to_visit.append(link)
-            time.sleep(1)
-        return contents
-    def crawl_web(self, query):
-        urls = self.search_duckduckgo(query)
-        if not urls:
-            print("هیچ نتیجه‌ای یافت نشد.")
-            return []
-        all_results = []
-        for url in urls[:3]:  # Limit to first 3 URLs for efficiency
-            content, links = self.crawl_page(url)
-            if content and content["paragraphs"]:
-                all_results.append(content)
-                # Follow links from the main page (recursive crawling)
-                for link in links[:2]:  # Limit to first 2 links
-                    sub_content, _ = self.crawl_page(link, depth=1)
-                    if sub_content and sub_content["paragraphs"]:
-                        all_results.append(sub_content)
-                    time.sleep(1)
-            time.sleep(1)
-        self.web_results = all_results
-        self._save_web_results()
-        # Convert web results to documents for RAG
-        web_docs = []
-        for result in all_results:
-            text = f"[{result['title']}]\n" + "\n".join(result['paragraphs'])
-            web_docs.append({"page_content": text, "metadata": {"source": result['url']}})
-        return all_results, web_docs
-    def build_retriever(self, documents):
-        if not documents:
-            return None
-        # Create BM25 retriever
-        bm25_retriever = BM25Retriever.from_documents(documents)
-        bm25_retriever.k = 3  # Return top 3 results
-        return bm25_retriever
-    def get_relevant_documents(self, query, documents):
-        retriever = self.build_retriever(documents)
-        if not retriever:
-            return []
-        return retriever.get_relevant_documents(query)
-    def extract_context_from_documents(self, query):
-        if not self.documents:
-            return None
-        relevant_docs = self.get_relevant_documents(query, self.documents)
-        if not relevant_docs:
-            return None
-        context = "\n\n".join([doc.page_content for doc in relevant_docs])
-        return context
-    def extract_context_from_web(self, web_results, web_docs, query):
-        if not web_results or not web_docs:
-            return None, []
-        # Try to use the retriever for better results
-        if web_docs:
-            relevant_docs = self.get_relevant_documents(query, web_docs)
-            if relevant_docs:
-                context = "\n\n".join([doc.page_content for doc in relevant_docs])
-                sources = [doc.metadata.get("source", "") for doc in relevant_docs if "source" in doc.metadata]
-                return context, sources
-        # Fall back to simple extraction if retriever fails
-        contexts = []
-        sources = []
-        for doc in web_results:
-            context_text = "\n".join(doc["paragraphs"])
-            contexts.append(f"[{doc['title']}] {context_text}")
-            sources.append(doc['url'])
-        context = "\n\n".join(contexts)
-        return context, sources
-def get_context(query, crawl_params=None):
-    """
-    سیستم RAG مدولار برای پاسخگویی به سوالات با استفاده از اسناد و جستجوی وب
-    پارامترها:
-        query (str): سوال به زبان فارسی
-        crawl_params (dict, optional): پارامترهای خزش وب
-            - max_depth: حداکثر عمق خزش
-            - max_links_per_page: حداکثر تعداد لینکهای استخراج شده از هر صفحه
-            - max_paragraphs: حداکثر تعداد پاراگرافهای استخراج شده از هر صفحه
-            - num_search_results: تعداد نتایج جستجو
-    خروجی:
-        dict: نتیجه جستجو شامل متن و منابع
-    """
-    rag = ModularRAG()
-    # Configure crawling parameters if provided
-    if crawl_params:
-        if 'max_depth' in crawl_params:
-            rag.max_depth = crawl_params['max_depth']
-        if 'max_links_per_page' in crawl_params:
-            rag.max_links_per_page = crawl_params['max_links_per_page']
-        if 'max_paragraphs' in crawl_params:
-            rag.max_paragraphs = crawl_params['max_paragraphs']
-        if 'num_search_results' in crawl_params:
-            rag.num_search_results = crawl_params['num_search_results']
-    # First try to get context from documents
-    doc_context = rag.extract_context_from_documents(query)
-    if doc_context:
-        return {
-            "has_context": True,
-            "context": doc_context,
-            "source": "documents",
-            "language": "fa"
-        }
-    # Fall back to web search
-    web_results, web_docs = rag.crawl_web(query)
-    if web_results:
-        web_context, sources = rag.extract_context_from_web(web_results, web_docs, query)
-        return {
-            "has_context": True,
-            "context": web_context,
-            "source": "web",
-            "sources": sources,
-            "language": "fa"
-        }
-    # No context found
-    return {
-        "has_context": False,
-        "context": "متأسفانه اطلاعاتی در مورد سوال شما یافت نشد.",
-        "source": "none",
-        "language": "fa"
-    }
+            response = qa_system.get_answer(user_input)
+            print(f"\nBot: {response}")
+    elif args.query:
+        qa_system = AgenticQASystem()
+        print(qa_system.get_answer(args.query))
+    elif args.test:
+        print("Running tests...")

hybrid.py (new file, 204 lines)

@@ -0,0 +1,204 @@
import os
import nltk
from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from typing_extensions import TypedDict
from langgraph.graph import START, END, StateGraph
# Ensure NLTK tokenizer is available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
# Initialize model and embeddings
model = ChatOllama(model="gemma3:12b", temperature=0.2)
embeddings = OllamaEmbeddings(model="gemma3:12b")
# Vector store
vector_store = InMemoryVectorStore(embeddings)
# Templates
qa_template = """
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:
"""
# Text splitter
def split_text(documents):
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True
    )
    return text_splitter.split_documents(documents)

# PDF handling
def load_pdf(file_path):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    loader = PDFPlumberLoader(file_path)
    documents = loader.load()
    return documents

# Web page handling (using WebBaseLoader)
def load_webpage(url):
    loader = WebBaseLoader(url)
    documents = loader.load()
    return documents

# Hybrid retriever
def build_hybrid_retriever(documents):
    global vector_store
    # Rebuild the store on each call; InMemoryVectorStore exposes no documented
    # clear() method, and reusing one store would accumulate stale documents.
    vector_store = InMemoryVectorStore(embeddings)
    vector_store.add_documents(documents)
    semantic_retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = 3
    hybrid_retriever = EnsembleRetriever(
        retrievers=[semantic_retriever, bm25_retriever],
        weights=[0.7, 0.3]
    )
    return hybrid_retriever

# DuckDuckGo search implementation
def search_ddg(query, num_results=3):
    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
    search = DuckDuckGoSearchAPIWrapper()
    results = search.results(query, num_results)
    return results

# Answer question with error handling
def answer_question(question, documents):
    try:
        context = "\n\n".join([doc.page_content for doc in documents])
        prompt = ChatPromptTemplate.from_template(qa_template)
        chain = prompt | model
        return chain.invoke({"question": question, "context": context}).content
    except Exception as e:
        return f"Error generating answer: {e}"
# Simple RAG node for web search
class WebSearchState(TypedDict):
    query: str
    results: list
    response: str

def web_search(state):
    results = search_ddg(state["query"])
    return {"results": results}

def generate_search_response(state):
    try:
        context = "\n\n".join([f"{r['title']}: {r['snippet']}" for r in state["results"]])
        prompt = ChatPromptTemplate.from_template(qa_template)
        chain = prompt | model
        response = chain.invoke({"question": state["query"], "context": context})
        return {"response": response.content}
    except Exception as e:
        return {"response": f"Error generating response: {e}"}
# Build search graph
search_graph = StateGraph(WebSearchState)
search_graph.add_node("search", web_search)
search_graph.add_node("generate", generate_search_response)
search_graph.add_edge(START, "search")
search_graph.add_edge("search", "generate")
search_graph.add_edge("generate", END)
search_workflow = search_graph.compile()
# Main command-line interface
if __name__ == "__main__":
print("Welcome to the Advanced RAG System")
print("Choose an option:")
print("1. Analyze PDF")
print("2. Crawl URL")
print("3. Search Internet")
choice = input("Enter your choice (1/2/3): ")
if choice == "1":
pdf_path = input("Enter the path to the PDF file: ").strip()
if not pdf_path:
print("Please enter a valid file path.")
else:
try:
print("Processing PDF...")
documents = load_pdf(pdf_path)
if not documents:
print("No documents were loaded from the PDF. The file might be empty or not contain extractable text.")
else:
chunked_documents = split_text(documents)
if not chunked_documents:
print("No text chunks were created. The PDF might not contain any text.")
else:
retriever = build_hybrid_retriever(chunked_documents)
print(f"Processed {len(chunked_documents)} chunks")
question = input("Ask a question about the PDF: ").strip()
if not question:
print("Please enter a valid question.")
else:
print("Searching document...")
related_documents = retriever.get_relevant_documents(question)
if not related_documents:
print("No relevant documents found for the question.")
else:
answer = answer_question(question, related_documents)
print("Answer:", answer)
except Exception as e:
print(f"Error: {e}")
elif choice == "2":
url = input("Enter the URL to analyze: ").strip()
if not url:
print("Please enter a valid URL.")
else:
try:
print("Loading webpage...")
web_documents = load_webpage(url)
if not web_documents:
print("No documents were loaded from the webpage. The page might be empty or not contain extractable text.")
else:
web_chunks = split_text(web_documents)
if not web_chunks:
print("No text chunks were created. The webpage might not contain any text.")
else:
web_retriever = build_hybrid_retriever(web_chunks)
print(f"Processed {len(web_chunks)} chunks from webpage")
question = input("Ask a question about the webpage: ").strip()
if not question:
print("Please enter a valid question.")
else:
print("Analyzing content...")
web_results = web_retriever.get_relevant_documents(question)
if not web_results:
print("No relevant documents found for the question.")
else:
answer = answer_question(question, web_results)
print("Answer:", answer)
except Exception as e:
print(f"Error loading webpage: {e}")
elif choice == "3":
query = input("Enter your search query: ").strip()
if not query:
print("Please enter a valid search query.")
else:
try:
print("Searching the web...")
search_result = search_workflow.invoke({"query": query})
print("Response:", search_result["response"])
print("Sources:")
for result in search_result["results"]:
print(f"- {result['title']}: {result['link']}")
except Exception as e:
print(f"Error during search: {e}")
else:
print("Invalid choice")

memory.py (new file, 198 lines)

@@ -0,0 +1,198 @@
# --- Dependencies ---
# pip install langchain langchain-core langchain-ollama faiss-cpu sentence-transformers
import datetime
import os
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain.memory import ConversationBufferMemory # Added for intra-session memory
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda, RunnableParallel
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document # Needed for manual saving
# --- Config ---
FAISS_INDEX_PATH = "my_chatbot_memory_index" # Directory to save/load FAISS index
# --- Ollama LLM & Embeddings Setup ---
# Run in terminal: ollama pull gemma3
# Run in terminal: ollama pull nomic-embed-text
OLLAMA_LLM_MODEL = 'gemma3' # Using Gemma 3 as requested
OLLAMA_EMBED_MODEL = 'nomic-embed-text' # Recommended embedding model for Ollama
try:
    llm = ChatOllama(model=OLLAMA_LLM_MODEL)
    embeddings = OllamaEmbeddings(model=OLLAMA_EMBED_MODEL)
    print(f"Successfully initialized Ollama: LLM='{OLLAMA_LLM_MODEL}', Embeddings='{OLLAMA_EMBED_MODEL}'")
    # Optional tests removed for brevity
except Exception as e:
    print(f"Error initializing Ollama components: {e}")
    print(f"Ensure Ollama is running & models pulled (e.g., 'ollama pull {OLLAMA_LLM_MODEL}' and 'ollama pull {OLLAMA_EMBED_MODEL}').")
    exit()
# --- Vector Store (Episodic Memory) Setup --- Persisted!
try:
    if os.path.exists(FAISS_INDEX_PATH):
        print(f"Loading existing FAISS index from: {FAISS_INDEX_PATH}")
        vectorstore = FAISS.load_local(
            FAISS_INDEX_PATH,
            embeddings,
            allow_dangerous_deserialization=True  # Required for FAISS loading
        )
        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
        print("FAISS vector store loaded successfully.")
    else:
        print(f"No FAISS index found at {FAISS_INDEX_PATH}. Initializing new store.")
        # FAISS needs at least one text to initialize.
        vectorstore = FAISS.from_texts(
            ["Initial conversation context placeholder - Bot created"],
            embeddings
        )
        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
        # Save the initial empty index
        vectorstore.save_local(FAISS_INDEX_PATH)
        print("New FAISS vector store initialized and saved.")
except Exception as e:
    print(f"Error initializing/loading FAISS: {e}")
    print("Check permissions or delete the index directory if corrupted.")
    exit()
# --- Conversation Buffer (Short-Term) Memory Setup ---
# memory_key must match the input variable in the prompt
# return_messages=True formats history as suitable list of BaseMessages
buffer_memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)
# <<< ADDED: Clear buffer at the start of each script run >>>
buffer_memory.clear()
# --- Define the Prompt Template ---
# Now includes chat_history for the buffer memory
template = """You are a helpful chatbot assistant with episodic memory (from past sessions) and conversational awareness (from the current session).
Use the following relevant pieces of information:
1. Episodic Memory (Knowledge from *previous* chat sessions):
{semantic_context}
2. Chat History (What we've discussed in the *current* session):
{chat_history}
Combine this information with the current user input to generate a coherent and contextually relevant answer.
If recalling information from Episodic Memory, you can mention it stems from a past conversation if appropriate.
If no relevant context or history is found, just respond naturally to the current input.
Current Input:
User: {input}
Assistant:"""
prompt = PromptTemplate(
    input_variables=["semantic_context", "chat_history", "input"],
    template=template
)
# --- Helper Function for Formatting Retrieved Docs (Episodic Memory) ---
# Formats the retrieved documents (past interactions) for the prompt
def format_retrieved_docs(docs):
    # Simplified formatting: extract core content only and label explicitly
    formatted = []
    for doc in docs:
        content = doc.page_content
        # Basic check to remove placeholder
        if content not in ["Initial conversation context placeholder - Bot created"]:
            # Strip the "Role (timestamp): " prefix if present
            if "):" in content:
                content = content.split("):", 1)[-1].strip()
            if content:  # Ensure content is not empty after stripping
                formatted.append(f"Recalled from a past session: {content}")
    # Use a double newline to separate recalled memories clearly
    return "\n\n".join(formatted) if formatted else "No relevant memories found from past sessions."
# --- Chain Definition using LCEL ---
# Function to load episodic memory (FAISS context)
def load_episodic_memory(input_dict):
    query = input_dict.get("input", "")
    docs = retriever.invoke(query)
    return format_retrieved_docs(docs)

# Function to save episodic memory (and persist FAISS index)
def save_episodic_memory_step(inputs_outputs):
    user_input = inputs_outputs.get("input", "")
    llm_output = inputs_outputs.get("output", "")
    if user_input and llm_output:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        docs_to_add = [
            Document(page_content=f"User ({timestamp}): {user_input}"),
            Document(page_content=f"Assistant ({timestamp}): {llm_output}")
        ]
        vectorstore.add_documents(docs_to_add)
        vectorstore.save_local(FAISS_INDEX_PATH)  # Persist index after adding
        # print(f"DEBUG: Saved to FAISS index: {FAISS_INDEX_PATH}")
    return inputs_outputs  # Pass the dict through for potential further steps
# Define the core chain logic
chain_core = (
    RunnablePassthrough.assign(
        semantic_context=RunnableLambda(load_episodic_memory),
        chat_history=RunnableLambda(lambda x: buffer_memory.load_memory_variables(x)['chat_history'])
    )
    | prompt
    | llm
    | StrOutputParser()
)
# Wrap the core logic to handle memory updates
def run_chain(input_dict):
    user_input = input_dict['input']
    # Invoke the core chain to get the response
    llm_response = chain_core.invoke({"input": user_input})
    # Prepare data for saving
    save_data = {"input": user_input, "output": llm_response}
    # Save to episodic memory (FAISS)
    save_episodic_memory_step(save_data)
    # Save to buffer memory
    buffer_memory.save_context({"input": user_input}, {"output": llm_response})
    return llm_response
# --- Chat Loop ---
print(f"\nChatbot Ready! Using Ollama ('{OLLAMA_LLM_MODEL}' chat, '{OLLAMA_EMBED_MODEL}' embed)")
print(f"Episodic memory stored in: {FAISS_INDEX_PATH}")
print("Type 'quit', 'exit', or 'bye' to end the conversation.")
while True:
    user_text = input("You: ")
    if user_text.lower() in ["quit", "exit", "bye"]:
        # Optionally clear buffer memory on exit if desired
        buffer_memory.clear()
        print("Chatbot: Goodbye!")
        break
    if not user_text:
        continue
    try:
        # Use the wrapper function to handle the chain invocation and memory updates
        response = run_chain({"input": user_text})
        print(f"Chatbot: {response}")
        # Optional debug: view buffer memory
        # print("DEBUG: Buffer Memory:", buffer_memory.load_memory_variables({}))
        # Optional debug: check vector store size
        # print(f"DEBUG: Vector store size: {vectorstore.index.ntotal}")
    except Exception as e:
        print(f"\nAn error occurred during the chat chain: {e}")
        # Add more detailed error logging if needed
        import traceback
        print(traceback.format_exc())
# --- End of Script ---
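
Note: a quick way to sanity-check the persisted episodic memory between sessions, reusing this script's own constants; run it only after at least one chat turn has been saved:

from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings

embeddings = OllamaEmbeddings(model="nomic-embed-text")
store = FAISS.load_local("my_chatbot_memory_index", embeddings,
                         allow_dangerous_deserialization=True)
for doc in store.similarity_search("my name", k=3):
    print(doc.page_content)   # past User/Assistant turns with timestamps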

req.txt (new file, 6 lines)

@@ -0,0 +1,6 @@
streamlit
langchain-core
langchain-ollama
unstructured[pdf]
langchain-text-splitters
pypdf