My AI Journey Day 2: Playing with LangChain RAG
Experimenting with LangChain-based RAG.
Steps that need to be executed:
- Initialize OllamaEmbeddings to use a local Ollama embedding model
- Initialize OllamaLLM to use a local Ollama LLM
- Initialize ChromaDB as a local vector store
- Use PyPDFLoader to extract text from the PDF document
- Load and split the PDF document
- Initialize RecursiveCharacterTextSplitter for text chunking
- Chunk each extracted page of the PDF document and generate an ID for each chunk in the list
- Add only the new chunks to the local vector store
- Run a similarity search query against the vector store
- Initialize ChatPromptTemplate with a predefined prompt
- Build the final prompt from the ChatPromptTemplate and the search context from the vector store
- Execute the LLM query
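The full script below wires these steps together. It assumes a local Ollama server with the mxbai-embed-large and tinyllama models already pulled, plus the langchain-community, langchain-ollama, langchain-chroma, langchain-text-splitters and pypdf packages installed. Before running the whole pipeline, a quick sanity check (just a sketch) that the embedding model responds:

# Sanity check (a sketch): verify the local Ollama embedding model answers.
from langchain_ollama import OllamaEmbeddings

emb = OllamaEmbeddings(model="mxbai-embed-large:latest")
print(len(emb.embed_query("hello")))  # embedding dimensionality (1024 for mxbai-embed-large)
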
#!/usr/bin/env python3
import asyncio
from uuid import uuid4

from langchain_community.document_loaders import PyPDFLoader
from langchain_ollama import OllamaEmbeddings, OllamaLLM
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
class FAQChat:
    def __init__(self, db_path="./", collection_name="example_collection",
                 embedding_llm="mxbai-embed-large:latest",
                 generating_llm="tinyllama:latest", pdf_file="") -> None:
        self.db_path = db_path
        self.collection_name = collection_name
        self.embedding_llm = embedding_llm
        # Local Ollama models: one for embeddings, one for answer generation.
        self.embeddings = OllamaEmbeddings(model=self.embedding_llm)
        self.pdf_file = pdf_file
        self.model = OllamaLLM(model=generating_llm)
        # Persistent local Chroma collection backed by the embedding model.
        self.vector_store = Chroma(
            collection_name=self.collection_name,
            embedding_function=self.embeddings,
            persist_directory=self.db_path,
        )
        # The prompt grounds the LLM in the retrieved context and tells it
        # to admit uncertainty rather than make up an answer.
        self.prompt_template = """
        Use the following pieces of context to answer the question at the end.
        If you don't know the answer, just say that you are unsure.
        Don't try to make up an answer.

        {context}

        Question: {question}
        Answer:
        """
    async def pdf_extract_and_vectorize(self):
        """Vectorize the PDF page by page (one document per page)."""
        loader = PyPDFLoader(self.pdf_file)
        pages = []
        documents_data = []
        async for page in loader.alazy_load():
            pages.append(page)
        for i, page in enumerate(pages):
            single_page_data = page.page_content
            document_page = Document(
                page_content=single_page_data,
                metadata={"source": "faq"},
                id=i,
            )
            documents_data.append(document_page)
            print(single_page_data)
        # One random UUID per page; these become the vector-store IDs.
        uuids = [str(uuid4()) for _ in range(len(documents_data))]
        self.vector_store.add_documents(documents=documents_data, ids=uuids)
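    # Note: pdf_extract_and_vectorize stores one vector per PDF page, while the
    # chunked variant below stores one vector per 400-character chunk. Smaller
    # chunks tend to retrieve better, since a chunk is a much tighter semantic
    # match for a single question than a whole page.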
    async def pdf_extract_and_vectorize_chunks(self):
        """Vectorize the PDF in overlapping chunks, skipping chunks already stored."""
        loader = PyPDFLoader(self.pdf_file)
        pages = loader.load()
        chunk_splitter = RecursiveCharacterTextSplitter(
            chunk_size=400, chunk_overlap=200, length_function=len, is_separator_regex=False
        )
        page_chunks = chunk_splitter.split_documents(pages)
        for index, chunk in enumerate(page_chunks):
            # Use a deterministic ID (source, page, chunk index) instead of a
            # random UUID, so re-runs recognize chunks that are already stored.
            chunk.metadata["chunk_id"] = (
                f"{chunk.metadata.get('source')}:{chunk.metadata.get('page')}:{index}"
            )
        present_ids = set(self.vector_store.get()["ids"])
        new_chunks = [c for c in page_chunks if c.metadata["chunk_id"] not in present_ids]
        if new_chunks:
            self.vector_store.add_documents(new_chunks, ids=[c.metadata["chunk_id"] for c in new_chunks])
        else:
            print("Nothing to persist")
    async def chat_thru_vectors(self, query_text="Who has written the programming language Python?"):
        # Retrieve the single best-matching chunk and splice it into the prompt.
        context = self.vector_store.similarity_search_with_score(query_text, k=1)
        context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in context)
        prompt_template = ChatPromptTemplate.from_template(self.prompt_template)
        prompt = prompt_template.format(context=context_text, question=query_text)
        response_text = self.model.invoke(prompt)
        print(response_text)
def main():
    faq_chat = FAQChat(
        "./chroma_langchain_db",
        "example_collection",
        "mxbai-embed-large:latest",
        "tinyllama:latest",
        "./Ethernet_FAQ.pdf",
    )
    # asyncio.run(faq_chat.pdf_extract_and_vectorize())
    print("Vectorizing")
    asyncio.run(faq_chat.pdf_extract_and_vectorize_chunks())
    print("Retrieving")
    asyncio.run(faq_chat.chat_thru_vectors("What is a network heartbeat, give me a fast, simple and short answer?"))

if __name__ == '__main__':
    main()
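
As a side note, the retrieve-then-generate flow in chat_thru_vectors can also be expressed in LangChain's LCEL pipe style. Below is just a sketch of that alternative, reusing the vector_store, model and prompt_template attributes from FAQChat (build_chain is my own hypothetical helper, not part of the script above):

# LCEL sketch of the same retrieve-then-generate flow (an alternative
# to chat_thru_vectors, not part of the script above).
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser

def build_chain(vector_store, model, prompt_template):
    retriever = vector_store.as_retriever(search_kwargs={"k": 1})
    prompt = ChatPromptTemplate.from_template(prompt_template)

    def format_docs(docs):
        # Join retrieved chunks with the same separator used above.
        return "\n\n---\n\n".join(doc.page_content for doc in docs)

    # question -> {context, question} -> prompt -> LLM -> string
    return (
        {"context": retriever | format_docs, "question": RunnablePassthrough()}
        | prompt
        | model
        | StrOutputParser()
    )

# Usage (hypothetical, reusing the objects from FAQChat):
# chain = build_chain(faq_chat.vector_store, faq_chat.model, faq_chat.prompt_template)
# print(chain.invoke("What is a network heartbeat?"))

On the splitter settings: chunk_size=400 with chunk_overlap=200 means consecutive chunks share half their text, so an answer that straddles a chunk boundary still appears intact in at least one chunk, at the cost of roughly doubling the number of stored vectors.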