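"""Command-line retrieval-augmented QA over a PDF, a web page, or web search.

The user picks one of three options from a menu: analyze a PDF, crawl a URL,
or search the internet (DuckDuckGo). Answers are generated by an Ollama chat
model from retrieved context using ``qa_template``.
"""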
import os
from urllib.parse import urlparse

import nltk
from langchain.retrievers import EnsembleRetriever
from langchain_community.document_loaders import PDFPlumberLoader, WebBaseLoader
from langchain_community.retrievers import BM25Retriever
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict


# Ensure NLTK tokenizer is available
def _ensure_nltk_resource(resource_path, package_name):
    """Download NLTK package ``package_name`` unless ``resource_path`` is installed."""
    try:
        nltk.data.find(resource_path)
    except LookupError:
        # nltk.data.find signals a missing local resource with LookupError.
        nltk.download(package_name)


# NOTE(review): no code in this file calls NLTK directly; confirm punkt is
# really needed before keeping this (possibly networked) download at import.
_ensure_nltk_resource('tokenizers/punkt', 'punkt')


# Initialize model and embeddings
# The chat model and the embedder use the same Ollama model tag.
_OLLAMA_MODEL = "gemma3:12b"

model = ChatOllama(model=_OLLAMA_MODEL, temperature=0.2)
# NOTE(review): a generative model is used for embeddings here; a dedicated
# embedding model may retrieve better — confirm before changing the tag.
embeddings = OllamaEmbeddings(model=_OLLAMA_MODEL)

# Vector store
# Single module-level store shared by every call to build_hybrid_retriever().
vector_store = InMemoryVectorStore(embeddings)


# Templates
# Prompt shared by document QA (answer_question) and web-search answers
# (generate_search_response). ChatPromptTemplate.from_template fills the
# {question} and {context} placeholders.
qa_template = """
You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question}
Context: {context}
Answer:
"""


# Text splitter
def split_text(documents, chunk_size=1000, chunk_overlap=200):
    """Split documents into overlapping chunks for indexing.

    Args:
        documents: LangChain ``Document`` objects to split.
        chunk_size: Maximum chunk size passed to the splitter (default 1000).
        chunk_overlap: Overlap between consecutive chunks (default 200).

    Returns:
        The chunk ``Document`` list from ``split_documents``. Because
        ``add_start_index=True``, each chunk records its start offset in
        its metadata.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        add_start_index=True,
    )
    return text_splitter.split_documents(documents)


# PDF handling
def load_pdf(file_path):
    """Load a local PDF file into LangChain documents with PDFPlumberLoader.

    Args:
        file_path: Path to the PDF. A leading ``~`` is expanded to the
            user's home directory.

    Returns:
        The list of documents returned by ``PDFPlumberLoader.load()``
        (presumably one per page — confirm against the loader docs).

    Raises:
        FileNotFoundError: If the path is not an existing regular file.
    """
    # Expand "~" before checking: os.path.exists("~/x.pdf") is always False,
    # so such input used to be rejected even when the file existed.
    resolved_path = os.path.expanduser(file_path)
    # isfile (not exists) so a directory is rejected here with a clear error
    # instead of failing later inside the loader.
    if not os.path.isfile(resolved_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    loader = PDFPlumberLoader(resolved_path)
    return loader.load()


# Web page handling (using WebBaseLoader)
def load_webpage(url):
    """Fetch a web page and return its content as LangChain documents.

    Args:
        url: Address of the page. If it has no scheme (e.g. ``example.com``),
            ``https://`` is prepended.

    Returns:
        The list of documents returned by ``WebBaseLoader.load()``.

    Raises:
        ValueError: If the URL's scheme is not http/https or it has no host.
    """
    # Users commonly type bare hostnames; without a scheme the HTTP client
    # cannot fetch the page, so default to https.
    if "://" not in url:
        url = f"https://{url}"
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(f"Unsupported URL: {url}")
    loader = WebBaseLoader(url)
    return loader.load()


# Hybrid retriever
def build_hybrid_retriever(documents, k=3, weights=(0.7, 0.3)):
    """Index ``documents`` and return a semantic + BM25 ensemble retriever.

    The module-level ``vector_store`` is emptied and refilled with
    ``documents``, so only the most recently indexed source is searchable.

    Args:
        documents: Non-empty list of LangChain ``Document`` chunks.
        k: Number of results each underlying retriever returns (default 3).
        weights: Ensemble weights for the (semantic, BM25) retrievers
            (default 0.7 and 0.3).

    Returns:
        An ``EnsembleRetriever`` combining both retrievers.

    Raises:
        ValueError: If ``documents`` is empty (BM25 indexing fails on an
            empty corpus).
    """
    if not documents:
        raise ValueError("Cannot build a retriever from an empty document list.")

    # InMemoryVectorStore has no clear() method; delete every stored id so
    # chunks from a previously indexed source do not leak into this one.
    existing_ids = list(vector_store.store)
    if existing_ids:
        vector_store.delete(ids=existing_ids)
    vector_store.add_documents(documents)

    semantic_retriever = vector_store.as_retriever(search_kwargs={"k": k})
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = k

    return EnsembleRetriever(
        retrievers=[semantic_retriever, bm25_retriever],
        weights=list(weights),
    )


# DuckDuckGo search implementation
def search_ddg(query, num_results=3):
    """Run a DuckDuckGo web search and return the wrapper's result list.

    Args:
        query: Search string.
        num_results: Maximum number of results to request (default 3).

    Returns:
        Whatever ``DuckDuckGoSearchAPIWrapper.results`` returns. Callers in
        this file read the ``title``, ``snippet`` and ``link`` keys of each
        entry.
    """
    # Imported inside the function, so the import only happens when a web
    # search actually runs.
    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper

    return DuckDuckGoSearchAPIWrapper().results(query, num_results)


# Answer question with error handling
def answer_question(question, documents):
    """Answer ``question`` with the chat model, using ``documents`` as context.

    The page contents of ``documents`` are joined with blank lines and
    substituted into ``qa_template``.

    Returns:
        The model's reply text or, if anything fails, a string starting with
        ``"Error generating answer: "``. Errors are returned, not raised.
    """
    try:
        context = "\n\n".join(doc.page_content for doc in documents)
        qa_chain = ChatPromptTemplate.from_template(qa_template) | model
        reply = qa_chain.invoke({"question": question, "context": context})
        return reply.content
    except Exception as e:
        return f"Error generating answer: {e}"


# Simple RAG node for web search
# Graph state for the web-search workflow:
#   query    - the user's search string (supplied when the workflow is invoked)
#   results  - result dicts written by the "search" node
#   response - answer text written by the "generate" node
WebSearchState = TypedDict(
    "WebSearchState",
    {
        "query": str,
        "results": list,
        "response": str,
    },
)


def web_search(state):
    """Graph node: search the web for ``state["query"]``.

    Returns a partial state update containing only the ``results`` key.
    """
    query = state["query"]
    found = search_ddg(query)
    return {"results": found}


def generate_search_response(state):
    """Graph node: answer ``state["query"]`` from the web-search snippets.

    Each usable result becomes a ``"title: snippet"`` line, or just the
    snippet when there is no title. Those lines are the context passed to
    ``qa_template``.

    Returns:
        A partial state update ``{"response": text}``. If no result carries
        a snippet, the model is not called and a fixed message is returned.
        On failure the text is an error message rather than a raised
        exception, as in ``answer_question``.
    """
    # Read keys defensively: an entry missing "title"/"snippet" (the search
    # wrapper may return a placeholder entry when nothing is found — confirm
    # for the installed version) used to raise KeyError here.
    context_lines = []
    for result in state.get("results") or []:
        snippet = result.get("snippet")
        if not snippet:
            continue
        title = result.get("title")
        context_lines.append(f"{title}: {snippet}" if title else snippet)

    if not context_lines:
        return {"response": "No web search results were found for this query."}

    try:
        prompt = ChatPromptTemplate.from_template(qa_template)
        chain = prompt | model
        response = chain.invoke(
            {"question": state["query"], "context": "\n\n".join(context_lines)}
        )
        return {"response": response.content}
    except Exception as e:
        return {"response": f"Error generating response: {e}"}


# Build search graph
def _build_search_graph():
    """Wire the web-search pipeline: START -> search -> generate -> END."""
    graph = StateGraph(WebSearchState)
    graph.add_node("search", web_search)
    graph.add_node("generate", generate_search_response)
    for source, target in ((START, "search"), ("search", "generate"), ("generate", END)):
        graph.add_edge(source, target)
    return graph


search_graph = _build_search_graph()
search_workflow = search_graph.compile()


# Main command-line interface
def _ask_question(chunks, processed_message, question_prompt, progress_message):
    """Index ``chunks``, prompt for one question, and print the answer.

    Shared tail of the PDF and URL flows; the caller supplies the
    source-specific console messages.
    """
    retriever = build_hybrid_retriever(chunks)
    print(processed_message)
    question = input(question_prompt).strip()
    if not question:
        print("Please enter a valid question.")
        return
    print(progress_message)
    # invoke() replaces the deprecated BaseRetriever.get_relevant_documents().
    related_documents = retriever.invoke(question)
    if not related_documents:
        print("No relevant documents found for the question.")
        return
    answer = answer_question(question, related_documents)
    print("Answer:", answer)


def _analyze_pdf():
    """Menu option 1: load a PDF chosen by the user and answer a question."""
    pdf_path = input("Enter the path to the PDF file: ").strip()
    if not pdf_path:
        print("Please enter a valid file path.")
        return
    try:
        print("Processing PDF...")
        documents = load_pdf(pdf_path)
        if not documents:
            print("No documents were loaded from the PDF. The file might be empty or not contain extractable text.")
            return
        chunked_documents = split_text(documents)
        if not chunked_documents:
            print("No text chunks were created. The PDF might not contain any text.")
            return
        _ask_question(
            chunked_documents,
            f"Processed {len(chunked_documents)} chunks",
            "Ask a question about the PDF: ",
            "Searching document...",
        )
    except Exception as e:
        print(f"Error: {e}")


def _analyze_url():
    """Menu option 2: load a web page chosen by the user and answer a question."""
    url = input("Enter the URL to analyze: ").strip()
    if not url:
        print("Please enter a valid URL.")
        return
    try:
        print("Loading webpage...")
        web_documents = load_webpage(url)
        if not web_documents:
            print("No documents were loaded from the webpage. The page might be empty or not contain extractable text.")
            return
        web_chunks = split_text(web_documents)
        if not web_chunks:
            print("No text chunks were created. The webpage might not contain any text.")
            return
        _ask_question(
            web_chunks,
            f"Processed {len(web_chunks)} chunks from webpage",
            "Ask a question about the webpage: ",
            "Analyzing content...",
        )
    except Exception as e:
        print(f"Error loading webpage: {e}")


def _search_internet():
    """Menu option 3: run the web-search workflow and print answer and sources."""
    query = input("Enter your search query: ").strip()
    if not query:
        print("Please enter a valid search query.")
        return
    try:
        print("Searching the web...")
        search_result = search_workflow.invoke({"query": query})
        print("Response:", search_result["response"])
        print("Sources:")
        # Skip entries without a link instead of crashing on a missing key.
        for result in search_result.get("results") or []:
            link = result.get("link")
            if link:
                print(f"- {result.get('title', '')}: {link}")
    except Exception as e:
        print(f"Error during search: {e}")


if __name__ == "__main__":
    print("Welcome to the Advanced RAG System")
    print("Choose an option:")
    print("1. Analyze PDF")
    print("2. Crawl URL")
    print("3. Search Internet")
    # strip() so stray whitespace around the digit is not rejected.
    choice = input("Enter your choice (1/2/3): ").strip()
    actions = {"1": _analyze_pdf, "2": _analyze_url, "3": _search_internet}
    action = actions.get(choice)
    if action is None:
        print("Invalid choice")
    else:
        action()