# Multi-modal RAG Streamlit app: indexes PDFs, images, video frames, and audio
# into an in-memory vector store and answers questions in Persian via Ollama models.
import os
import subprocess


def clear_proxy_settings():
    """Remove proxy-related environment variables from this process.

    Both upper- and lower-case variants are cleared so local Ollama/HTTP
    calls made later in this script are not routed through a proxy.
    """
    for var in ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]:
        # pop() with a default is a no-op when the variable is absent,
        # avoiding the check-then-delete race of `if var in os.environ: del ...`.
        os.environ.pop(var, None)


# Run before the network-capable libraries below are imported.
clear_proxy_settings()
|
|
|
|
import os
import subprocess
import tempfile
from datetime import datetime

import streamlit as st
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import OllamaEmbeddings
from langchain_ollama.llms import OllamaLLM
from langchain_text_splitters import RecursiveCharacterTextSplitter
from unstructured.partition.pdf import partition_pdf
from unstructured.partition.utils.constants import PartitionStrategy

from search_utils import duckduckgo_search, rank_results
|
|
|
template = """
|
|
تو یک دستیار هستی که از یک داده های متنی و تصویری استفاده میکنی تا به سوالات کاربر به زبان فارسی سلیس پاسخ بدی.
|
|
Question: {question}
|
|
Context: {context}
|
|
Answer:
|
|
"""
|
|
|
|
# Storage locations for each media type handled by the app.
pdfs_directory = 'multi-modal-rag/pdfs/'
figures_directory = 'multi-modal-rag/figures/'
images_directory = 'multi-modal-rag/images/'
videos_directory = 'multi-modal-rag/videos/'
audio_directory = 'multi-modal-rag/audio/'
frames_directory = 'multi-modal-rag/frames/'

# Create directories if they don't exist.
for _directory in (pdfs_directory, figures_directory, images_directory,
                   videos_directory, audio_directory, frames_directory):
    os.makedirs(_directory, exist_ok=True)
|
|
|
|
# Embedding model used for all vector-store lookups; indexed text is embedded
# with llama3.2 via the local Ollama server.
embeddings = OllamaEmbeddings(model="llama3.2")
# Ephemeral store: the index lives only for this Streamlit process's lifetime.
vector_store = InMemoryVectorStore(embeddings)

# gemma3 serves both text generation and image description
# (image input is attached later via model.bind(images=...)).
model = OllamaLLM(model="gemma3")
|
|
|
|
def upload_pdf(file):
    """Save an uploaded PDF into pdfs_directory and return the saved path.

    Returning the destination path makes this consistent with the other
    upload_* helpers; existing callers that ignore the return value are
    unaffected.
    """
    file_path = pdfs_directory + file.name
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())
    return file_path
|
|
|
|
def upload_image(file):
    """Save an uploaded image into images_directory and return the saved path."""
    # Compute the destination once instead of rebuilding it for the return.
    file_path = images_directory + file.name
    with open(file_path, "wb") as f:
        f.write(file.getbuffer())
    return file_path
|
|
|
|
def upload_video(file):
    """Save an uploaded video into videos_directory and return the saved path."""
    destination = videos_directory + file.name
    with open(destination, "wb") as out:
        out.write(file.getbuffer())
    return destination
|
|
|
|
def upload_audio(file):
    """Save an uploaded audio file into audio_directory and return the saved path."""
    destination = audio_directory + file.name
    with open(destination, "wb") as out:
        out.write(file.getbuffer())
    return destination
|
|
|
|
def load_pdf(file_path):
    """Parse a PDF into one text blob, describing embedded images/tables via the LLM.

    Uses unstructured's HI_RES strategy so Image and Table regions are written
    as image files into figures_directory; each figure is then described with
    extract_text() and appended to the document text. Returns all pieces
    joined by blank lines.
    """
    elements = partition_pdf(
        file_path,
        strategy=PartitionStrategy.HI_RES,
        extract_image_block_types=["Image", "Table"],
        extract_image_block_output_dir=figures_directory
    )

    # Keep the plain text elements; Image/Table content is recovered from the
    # figure files written by partition_pdf above.
    text_elements = [element.text for element in elements if element.category not in ["Image", "Table"]]

    # NOTE(review): this walks every file currently in figures_directory,
    # including figures left over from previously processed PDFs — confirm
    # whether the directory should be cleared per document.
    for file in os.listdir(figures_directory):
        extracted_text = extract_text(figures_directory + file)
        text_elements.append(extracted_text)

    return "\n\n".join(text_elements)
|
|
|
|
def extract_frames(video_path, num_frames=5):
    """Extract up to `num_frames` still frames from a video into frames_directory.

    Returns the list of frame image paths that were successfully written.
    Frames are spread evenly across the video's duration (probed via ffprobe).

    BUG FIX: the original seek offsets were `i * (1/num_frames)` seconds —
    i.e. 0.0, 0.2, 0.4, ... — so every frame came from the first second of
    the video regardless of its length.
    """
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    # splitext keeps dotted base names intact (e.g. "my.video.mp4" -> "my.video").
    base_name = os.path.splitext(os.path.basename(video_path))[0]
    frame_paths = []

    # Probe the duration so seeks can be distributed across the whole clip.
    duration = None
    try:
        probe = subprocess.run(
            ['ffprobe', '-v', 'error', '-show_entries', 'format=duration',
             '-of', 'default=noprint_wrappers=1:nokey=1', video_path],
            check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        duration = float(probe.stdout)
    except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
        # ffprobe missing or unparsable output: fall back to the legacy offsets.
        duration = None

    # Extract frames using ffmpeg.
    for i in range(num_frames):
        if duration:
            seek = duration * i / num_frames
        else:
            seek = i * (1 / num_frames)
        frame_path = f"{frames_directory}{base_name}_{timestamp}_{i}.jpg"
        # '-ss' before '-i' makes ffmpeg seek on the input (fast seek).
        cmd = [
            'ffmpeg', '-ss', str(seek), '-i', video_path,
            '-vframes', '1', '-q:v', '2', frame_path, '-y'
        ]
        try:
            subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            frame_paths.append(frame_path)
        except subprocess.CalledProcessError:
            st.warning(f"Failed to extract frame {i} from video")

    return frame_paths
|
|
|
|
def process_audio(audio_path):
    """Process audio file using the model"""
    # NOTE(review): the LLM is given only the file *name*, never the audio
    # data, so the "description" is a guess based on the filename alone —
    # confirm whether real transcription (e.g. a speech model) was intended.
    audio_description = model.invoke(
        f"Describe what you hear in this audio file: {os.path.basename(audio_path)}"
    )
    return f"Audio file: {os.path.basename(audio_path)}. Description: {audio_description}"
|
|
|
|
def extract_text(file_path):
    """Return the model's textual description of the image at file_path."""
    vision_model = model.bind(images=[file_path])
    return vision_model.invoke("Tell me what do you see in this picture.")
|
|
|
|
def split_text(text):
    """Split text into overlapping ~1000-character chunks for indexing."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True,
    )
    return splitter.split_text(text)
|
|
|
|
def index_docs(texts):
    # Embed and add the given text chunks to the in-memory vector store.
    vector_store.add_texts(texts)
|
|
|
|
def retrieve_docs(query):
    # Return the stored chunks most similar to the query (store's default top-k).
    return vector_store.similarity_search(query)
|
|
|
|
def answer_question(question, documents):
    """Answer `question` (in Persian, per the template) using the retrieved docs as context."""
    context_text = "\n\n".join(doc.page_content for doc in documents)
    rag_chain = ChatPromptTemplate.from_template(template) | model
    return rag_chain.invoke({"question": question, "context": context_text})
|
|
|
|
# Sidebar for upload options
st.sidebar.title("Upload Documents")
# The selected mode decides which UI branch below is rendered on this rerun.
upload_option = st.sidebar.radio("Choose upload type:", ["PDF", "Image", "Video", "Audio", "Search"])
|
|
|
|
if upload_option == "Search":
    # Web search mode: fetch DuckDuckGo results, re-rank them with BM25,
    # and optionally answer a follow-up question grounded in the top hits.
    st.title("Web Search with BM25 Ranking")
    search_query = st.text_input("Enter your search query:")

    if search_query:
        with st.spinner("Searching and ranking results..."):
            # Get search results
            search_results = duckduckgo_search(search_query, max_results=10)

            if search_results:
                # Rank results using BM25
                ranked_results = rank_results(search_query, search_results)

                # Display results
                st.subheader("Ranked Search Results")
                for i, result in enumerate(ranked_results):
                    with st.expander(f"{i+1}. {result.title}"):
                        st.write(f"**Snippet:** {result.snippet}")
                        st.write(f"**URL:** {result.url}")

                # Option to ask about search results
                st.subheader("Ask about these results")
                question = st.text_input("Enter your question about the search results:")

                if question:
                    # Ground the answer in the top three ranked results only.
                    context = "\n\n".join([f"Title: {r.title}\nSnippet: {r.snippet}" for r in ranked_results[:3]])

                    # Use the model to answer
                    prompt = ChatPromptTemplate.from_template(template)
                    chain = prompt | model

                    with st.spinner("Generating answer..."):
                        response = chain.invoke({"question": question, "context": context})
                        st.markdown("### Answer")
                        # BUG FIX: OllamaLLM chains return a plain string (see
                        # answer_question / the chat handler below) — the original
                        # `response.content` raised AttributeError.
                        st.write(response)
            else:
                st.warning("No search results found")
|
|
|
|
elif upload_option == "PDF":
|
|
uploaded_file = st.file_uploader(
|
|
"Upload PDF",
|
|
type="pdf",
|
|
accept_multiple_files=False
|
|
)
|
|
|
|
if uploaded_file:
|
|
upload_pdf(uploaded_file)
|
|
with st.spinner("Processing PDF..."):
|
|
text = load_pdf(pdfs_directory + uploaded_file.name)
|
|
chunked_texts = split_text(text)
|
|
index_docs(chunked_texts)
|
|
st.success("PDF processed successfully!")
|
|
|
|
elif upload_option == "Image":
    # Image mode: save the upload, describe it with the vision model, index the text.
    uploaded_image = st.file_uploader(
        "Upload Image",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=False
    )

    if uploaded_image:
        image_path = upload_image(uploaded_image)
        # NOTE(review): `use_column_width` is deprecated in newer Streamlit
        # releases in favor of `use_container_width` — confirm the pinned version.
        st.image(image_path, caption="Uploaded Image", use_column_width=True)
        with st.spinner("Processing image..."):
            # What gets indexed is the model's textual description, not the pixels.
            image_description = extract_text(image_path)
            index_docs([image_description])
            st.success("Image processed and added to knowledge base")
|
|
|
|
elif upload_option == "Video":
|
|
uploaded_video = st.file_uploader(
|
|
"Upload Video",
|
|
type=["mp4", "avi", "mov", "mkv"],
|
|
accept_multiple_files=False
|
|
)
|
|
|
|
if uploaded_video:
|
|
video_path = upload_video(uploaded_video)
|
|
st.video(video_path)
|
|
|
|
with st.spinner("Processing video frames..."):
|
|
frame_paths = extract_frames(video_path)
|
|
video_descriptions = []
|
|
|
|
for frame_path in frame_paths:
|
|
st.image(frame_path, caption=f"Frame from video", width=200)
|
|
frame_description = extract_text(frame_path)
|
|
video_descriptions.append(frame_description)
|
|
|
|
# Add a combined description
|
|
combined_description = f"Video file: {uploaded_video.name}. Content description: " + " ".join(video_descriptions)
|
|
index_docs([combined_description])
|
|
st.success("Video processed and added to knowledge base")
|
|
|
|
else: # Audio option
|
|
uploaded_audio = st.file_uploader(
|
|
"Upload Audio",
|
|
type=["mp3", "wav", "ogg"],
|
|
accept_multiple_files=False
|
|
)
|
|
|
|
if uploaded_audio:
|
|
audio_path = upload_audio(uploaded_audio)
|
|
st.audio(audio_path)
|
|
|
|
with st.spinner("Processing audio..."):
|
|
# For audio, we'll use the model directly without visual context
|
|
audio_description = process_audio(audio_path)
|
|
index_docs([audio_description])
|
|
st.success("Audio processed and added to knowledge base")
|
|
|
|
# Chat interface
question = st.chat_input()

if question:
    # RAG loop: retrieve the most similar indexed chunks, then answer with
    # the Persian prompt template. answer_question returns a plain string.
    st.chat_message("user").write(question)
    related_documents = retrieve_docs(question)
    answer = answer_question(question, related_documents)
    st.chat_message("assistant").write(answer)