Context-Enriched RAG for More Coherent Responses
Context-Enriched Retrieval in RAG
Retrieval-Augmented Generation (RAG) enhances AI responses by retrieving relevant knowledge from external sources. Traditional retrieval methods return isolated text chunks, which can lead to incomplete answers.
To address this, we introduce Context-Enriched Retrieval, which ensures that retrieved information includes neighboring chunks for better coherence.
Steps in This Notebook:
- Data Ingestion: Extract text from a PDF.
- Chunking with Overlapping Context: Split text into overlapping chunks to preserve context.
- Embedding Creation: Convert text chunks into numerical representations.
- Context-Aware Retrieval: Retrieve relevant chunks along with their neighbors for better completeness.
- Response Generation: Use a language model to generate responses based on retrieved context.
- Evaluation: Assess the model's response accuracy.
Setting Up the Environment
We begin by importing necessary libraries.
import fitz  # PyMuPDF, used for PDF text extraction
import os
import numpy as np
import json
from openai import OpenAI
Extracting Text from a PDF File
To implement RAG, we first need a source of textual data. In this case, we extract text from a PDF file using the PyMuPDF library.
def extract_text_from_pdf(pdf_path):
    """Extract all text from a PDF file using PyMuPDF."""
    mypdf = fitz.open(pdf_path)
    all_text = ""

    # Concatenate the text of every page
    for page_num in range(mypdf.page_count):
        page = mypdf[page_num]
        text = page.get_text("text")
        all_text += text

    return all_text
Chunking the Extracted Text
Once we have the extracted text, we divide it into smaller, overlapping chunks to improve retrieval accuracy.
def chunk_text(text, n, overlap):
    """Split text into chunks of n characters, with consecutive chunks sharing `overlap` characters."""
    chunks = []

    # Advance by (n - overlap) characters so each chunk repeats the tail of the previous one
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])

    return chunks
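As a quick illustration of how the overlap behaves (this toy example is an addition, not part of the pipeline), splitting a 16-character string into 6-character chunks with an overlap of 2 makes each chunk repeat the last two characters of its predecessor:
sample = "abcdefghijklmnop"
print(chunk_text(sample, 6, 2))
# ['abcdef', 'efghij', 'ijklmn', 'mnop'] -- consecutive chunks share two characters,
# so text cut at a chunk boundary still appears intact in a neighboring chunk.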
Setting Up the OpenAI API Client
We initialize the OpenAI client to generate embeddings and responses.
# Initialize the OpenAI client, pointed at the Nebius AI Studio endpoint
client = OpenAI(
    base_url="https://api.studio.nebius.com/v1/",
    api_key=os.getenv("OPENAI_API_KEY")  # API key is read from an environment variable
)
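If the environment variable is not set, the client is created with an empty key and every request will fail later. A quick check like the one below (an addition, not in the original notebook) surfaces the problem early:
assert os.getenv("OPENAI_API_KEY"), "Please set the OPENAI_API_KEY environment variable before running this notebook."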
Extracting and Chunking Text from a PDF File
Now, we load the PDF, extract text, and split it into chunks.
# Extract the text from the PDF and split it into 1000-character chunks with 200 characters of overlap
pdf_path = "data/AI_Information.pdf"
extracted_text = extract_text_from_pdf(pdf_path)
text_chunks = chunk_text(extracted_text, 1000, 200)
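A quick check (an addition, not part of the original pipeline) shows how many chunks these settings produce:
print("Number of text chunks:", len(text_chunks))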
Creating Embeddings for Text Chunks
Embeddings transform text into numerical vectors, which allow for efficient similarity search.
def create_embeddings(text, model="BAAI/bge-en-icl"):
    """Create embeddings for a string or a list of strings."""
    response = client.embeddings.create(
        model=model,
        input=text
    )

    return response

# Create embeddings for all text chunks in a single request
response = create_embeddings(text_chunks)
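As a small sanity check (an addition, relying only on the OpenAI SDK response shape already used later in this notebook), we can confirm that one embedding was produced per chunk and inspect the vector dimensionality:
print("Number of embeddings:", len(response.data))
print("Embedding dimension:", len(response.data[0].embedding))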
Implementing Context-Aware Semantic Search
We modify retrieval to include neighboring chunks for better context.
def cosine_similarity(vec1, vec2):
    """Compute the cosine similarity between two vectors."""
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
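# Quick illustrative check (an addition, not from the original notebook):
# vectors pointing in the same direction score 1.0, orthogonal vectors score 0.0.
print(cosine_similarity(np.array([1.0, 0.0]), np.array([2.0, 0.0])))  # 1.0
print(cosine_similarity(np.array([1.0, 0.0]), np.array([0.0, 1.0])))  # 0.0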
def context_enriched_search(query, text_chunks, embeddings, k=1, context_size=1):
    """Retrieve the most relevant chunk together with its neighboring chunks."""
    # Embed the query so it can be compared against the chunk embeddings
    query_embedding = create_embeddings(query).data[0].embedding
    similarity_scores = []

    # Score every chunk against the query
    for i, chunk_embedding in enumerate(embeddings):
        similarity_score = cosine_similarity(np.array(query_embedding), np.array(chunk_embedding.embedding))
        similarity_scores.append((i, similarity_score))

    # Sort by similarity (highest first) and take the best-matching chunk
    similarity_scores.sort(key=lambda x: x[1], reverse=True)
    top_index = similarity_scores[0][0]

    # Include `context_size` neighboring chunks on each side, clamped to valid indices
    start = max(0, top_index - context_size)
    end = min(len(text_chunks), top_index + context_size + 1)

    return [text_chunks[i] for i in range(start, end)]
Running a Query with Context Retrieval
We now test context-enriched retrieval on a question from the validation set.
# Load the validation dataset and take the first question as our query
with open('data/val.json') as f:
    data = json.load(f)

query = data[0]['question']

# Retrieve the best-matching chunk plus one neighboring chunk on each side
top_chunks = context_enriched_search(query, text_chunks, response.data, k=1, context_size=1)

print("Query:", query)
for i, chunk in enumerate(top_chunks):
    print(f"Context {i + 1}:\n{chunk}\n=====================================")
Generating a Response Using Retrieved Context
We now generate a response with the LLM, using the retrieved chunks as context.
system_prompt = "You are an AI assistant that strictly answers based on the given context. If the answer cannot be derived directly from the provided context, respond with: 'I do not have enough information to answer that.'"
def generate_response(system_prompt, user_message, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Generate a chat completion from the given system prompt and user message."""
    response = client.chat.completions.create(
        model=model,
        temperature=0,  # Deterministic output for reproducibility
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message}
        ]
    )

    return response
# Assemble the user prompt: the retrieved context blocks followed by the question
user_prompt = "\n".join([f"Context {i + 1}:\n{chunk}\n=====================================\n" for i, chunk in enumerate(top_chunks)])
user_prompt = f"{user_prompt}\nQuestion: {query}"

# Generate the answer from the retrieved context
ai_response = generate_response(system_prompt, user_prompt)
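Before scoring the answer, it helps to print it directly (a small addition; the response object follows the same chat-completions shape accessed in the evaluation step below):
print(ai_response.choices[0].message.content)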
Evaluating the AI Response
We compare the AI response with the expected answer and assign a score.
evaluate_system_prompt = "You are an intelligent evaluation system tasked with assessing the AI assistant's responses. If the AI assistant's response is very close to the true response, assign a score of 1. If the response is incorrect or unsatisfactory in relation to the true response, assign a score of 0. If the response is partially aligned with the true response, assign a score of 0.5."
# Combine the query, the AI answer, and the reference answer into a single evaluation prompt
evaluation_prompt = f"User Query: {query}\nAI Response:\n{ai_response.choices[0].message.content}\nTrue Response: {data[0]['ideal_answer']}\n{evaluate_system_prompt}"

# Ask the LLM to score the response against the reference answer
evaluation_response = generate_response(evaluate_system_prompt, evaluation_prompt)
print(evaluation_response.choices[0].message.content)