# --- Dependencies ---
# pip install langchain langchain-core langchain-community langchain-ollama faiss-cpu

import datetime
import os

from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain.memory import ConversationBufferMemory  # Intra-session (short-term) memory
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain.schema import Document  # Needed for manually saving turns to the vector store

# --- Config ---
FAISS_INDEX_PATH = "my_chatbot_memory_index"  # Directory to save/load the FAISS index

# --- Ollama LLM & Embeddings Setup ---
# Run in terminal: ollama pull gemma3
# Run in terminal: ollama pull nomic-embed-text
OLLAMA_LLM_MODEL = 'gemma3'              # Chat model (Gemma 3)
OLLAMA_EMBED_MODEL = 'nomic-embed-text'  # Recommended embedding model for Ollama

try:
    llm = ChatOllama(model=OLLAMA_LLM_MODEL)
    embeddings = OllamaEmbeddings(model=OLLAMA_EMBED_MODEL)
    print(f"Successfully initialized Ollama: LLM='{OLLAMA_LLM_MODEL}', Embeddings='{OLLAMA_EMBED_MODEL}'")
    # Optional connectivity tests omitted for brevity (a commented example follows below)
except Exception as e:
    print(f"Error initializing Ollama components: {e}")
    print(f"Ensure Ollama is running and the models are pulled (e.g., 'ollama pull {OLLAMA_LLM_MODEL}' and 'ollama pull {OLLAMA_EMBED_MODEL}').")
    exit()

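# A minimal, optional sanity check (commented out). This is an illustrative sketch, not part
# of the original flow: it assumes the Ollama server is reachable and simply exercises one
# chat call and one embedding call before the chatbot starts.
# _reply = llm.invoke("Reply with the single word: ready")
# print(f"LLM check: {_reply.content}")
# _vec = embeddings.embed_query("connectivity test")
# print(f"Embedding check: vector of length {len(_vec)}")
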
# --- Vector Store (Episodic Memory) Setup --- persisted across sessions
try:
    if os.path.exists(FAISS_INDEX_PATH):
        print(f"Loading existing FAISS index from: {FAISS_INDEX_PATH}")
        vectorstore = FAISS.load_local(
            FAISS_INDEX_PATH,
            embeddings,
            # Required for FAISS loading: the docstore is pickled on disk,
            # so only load indexes you created yourself.
            allow_dangerous_deserialization=True
        )
        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
        print("FAISS vector store loaded successfully.")
    else:
        print(f"No FAISS index found at {FAISS_INDEX_PATH}. Initializing new store.")
        # FAISS needs at least one text to initialize.
        vectorstore = FAISS.from_texts(
            ["Initial conversation context placeholder - Bot created"],
            embeddings
        )
        retriever = vectorstore.as_retriever(search_kwargs=dict(k=3))
        # Save the initial index (it contains only the placeholder text)
        vectorstore.save_local(FAISS_INDEX_PATH)
        print("New FAISS vector store initialized and saved.")

except Exception as e:
    print(f"Error initializing/loading FAISS: {e}")
    print("Check permissions, or delete the index directory if it is corrupted.")
    exit()

# --- Conversation Buffer (Short-Term) Memory Setup ---
# memory_key must match the corresponding input variable in the prompt ("chat_history").
# return_messages=True returns the history as a list of BaseMessage objects rather than one string.
buffer_memory = ConversationBufferMemory(
    memory_key="chat_history",
    return_messages=True
)
# Clear the buffer at the start of each script run: short-term memory is per-session,
# while long-term (episodic) memory lives in the persisted FAISS index.
buffer_memory.clear()

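# For reference, after a turn has been saved the buffer looks roughly like:
#   buffer_memory.load_memory_variables({})
#   -> {"chat_history": [HumanMessage(content="..."), AIMessage(content="...")]}
# (illustrative shape only; the exact message objects come from LangChain)
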
# --- Define the Prompt Template ---
# Includes chat_history for the buffer memory alongside the retrieved episodic context.
template = """You are a helpful chatbot assistant with episodic memory (from past sessions) and conversational awareness (from the current session).
Use the following relevant pieces of information:
1. Episodic Memory (Knowledge from *previous* chat sessions):
{semantic_context}

2. Chat History (What we've discussed in the *current* session):
{chat_history}

Combine this information with the current user input to generate a coherent and contextually relevant answer.
If recalling information from Episodic Memory, you can mention that it stems from a past conversation if appropriate.
If no relevant context or history is found, just respond naturally to the current input.

Current Input:
User: {input}
Assistant:"""

prompt = PromptTemplate(
    input_variables=["semantic_context", "chat_history", "input"],
    template=template
)

# --- Helper Function for Formatting Retrieved Docs (Episodic Memory) ---
# Formats the retrieved documents (past interactions) for the prompt.
def format_retrieved_docs(docs):
    # Simplified formatting: extract the core content only and label it explicitly.
    formatted = []
    for doc in docs:
        content = doc.page_content
        # Skip the placeholder text used to initialize the index.
        if content not in ["Initial conversation context placeholder - Bot created"]:
            # Strip the "Role (timestamp): " prefix if present.
            if "):" in content:
                content = content.split("):", 1)[-1].strip()
            if content:  # Ensure content is not empty after stripping
                formatted.append(f"Recalled from a past session: {content}")
    # Use a double newline to separate recalled memories clearly.
    return "\n\n".join(formatted) if formatted else "No relevant memories found from past sessions."

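# For example, a stored document such as
#   "User (2025-01-01 12:00:00): I live in Oslo"
# would be surfaced to the prompt as
#   "Recalled from a past session: I live in Oslo"
# (the date and content above are purely illustrative).
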
# --- Chain Definition using LCEL ---

# Function to load episodic memory (FAISS context)
def load_episodic_memory(input_dict):
    query = input_dict.get("input", "")
    docs = retriever.invoke(query)
    return format_retrieved_docs(docs)

# Function to save episodic memory (and persist the FAISS index)
def save_episodic_memory_step(inputs_outputs):
    user_input = inputs_outputs.get("input", "")
    llm_output = inputs_outputs.get("output", "")

    if user_input and llm_output:
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        docs_to_add = [
            Document(page_content=f"User ({timestamp}): {user_input}"),
            Document(page_content=f"Assistant ({timestamp}): {llm_output}")
        ]
        vectorstore.add_documents(docs_to_add)
        vectorstore.save_local(FAISS_INDEX_PATH)  # Persist index after adding
        # print(f"DEBUG: Saved to FAISS index: {FAISS_INDEX_PATH}")
    return inputs_outputs  # Pass the dict through for potential further steps

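# Note: each turn is stored as two separate Documents (one for the user, one for the
# assistant), each prefixed with a role and timestamp. The index therefore grows by two
# entries per exchange and is rewritten to disk after every turn.
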
# Define the core chain logic
chain_core = (
    RunnablePassthrough.assign(
        semantic_context=RunnableLambda(load_episodic_memory),
        chat_history=RunnableLambda(lambda x: buffer_memory.load_memory_variables(x)['chat_history'])
    )
    | prompt
    | llm
    | StrOutputParser()
)

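# Data flow through chain_core:
#   {"input": ...}
#     -> RunnablePassthrough.assign adds "semantic_context" (FAISS retrieval) and
#        "chat_history" (buffer memory) alongside the original "input"
#     -> prompt fills the template with all three variables
#     -> llm generates the reply
#     -> StrOutputParser returns it as a plain string
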
# Wrap the core logic to handle memory updates
def run_chain(input_dict):
    user_input = input_dict['input']

    # Invoke the core chain to get the response
    llm_response = chain_core.invoke({"input": user_input})

    # Prepare data for saving
    save_data = {"input": user_input, "output": llm_response}

    # Save to episodic memory (FAISS)
    save_episodic_memory_step(save_data)

    # Save to buffer memory
    buffer_memory.save_context({"input": user_input}, {"output": llm_response})

    return llm_response

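# Example (illustrative): a single call such as
#   run_chain({"input": "My name is Alice and I like hiking."})
# returns the assistant's reply as a string and, as a side effect, stores the turn in
# both the FAISS index (long-term) and the conversation buffer (short-term).
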
# --- Chat Loop ---
print(f"\nChatbot Ready! Using Ollama ('{OLLAMA_LLM_MODEL}' chat, '{OLLAMA_EMBED_MODEL}' embed)")
print(f"Episodic memory stored in: {FAISS_INDEX_PATH}")
print("Type 'quit', 'exit', or 'bye' to end the conversation.")

while True:
    user_text = input("You: ")
    if user_text.lower() in ["quit", "exit", "bye"]:
        # Clear the short-term buffer on exit; the episodic FAISS memory persists on disk.
        buffer_memory.clear()
        print("Chatbot: Goodbye!")
        break
    if not user_text.strip():
        continue

    try:
        # Use the wrapper function to handle the chain invocation and memory updates
        response = run_chain({"input": user_text})
        print(f"Chatbot: {response}")

        # Optional debug: view buffer memory
        # print("DEBUG: Buffer Memory:", buffer_memory.load_memory_variables({}))
        # Optional debug: check vector store size
        # print(f"DEBUG: Vector store size: {vectorstore.index.ntotal}")

    except Exception as e:
        print(f"\nAn error occurred during the chat chain: {e}")
        import traceback
        print(traceback.format_exc())

# --- End of Script ---