Relevant Segment Extraction (RSE) for Enhanced RAG
Relevant Segment Extraction (RSE) for Enhanced RAG
In this notebook, we implement a Relevant Segment Extraction (RSE) technique to improve the context quality in our RAG system. Rather than simply retrieving a collection of isolated chunks, we identify and reconstruct continuous segments of text that provide better context to our language model.
Key Concept
Relevant chunks tend to be clustered together within documents. By identifying these clusters and preserving their continuity, we provide more coherent context for the LLM to work with.
Setting Up the Environment
We begin by importing necessary libraries.
import fitz
import os
import numpy as np
import json
from openai import OpenAI
import re
Extracting Text from a PDF File
To implement RAG, we first need a source of textual data. In this case, we extract text from a PDF file using the PyMuPDF library.
def extract_text_from_pdf(pdf_path):
    """Extract the plain text of every page of a PDF file.

    Args:
        pdf_path: Path of the PDF file to open with PyMuPDF.

    Returns:
        str: The text of all pages concatenated in page order, with no
        separator inserted between pages.
    """
    # The context manager closes the document even when extraction raises;
    # previously the handle was left open for the life of the process.
    with fitz.open(pdf_path) as document:
        # join() builds the result once instead of re-copying the
        # accumulated string for every page.
        return "".join(page.get_text("text") for page in document)
Chunking the Extracted Text
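Once we have the extracted text, we divide it into smaller chunks to improve retrieval accuracy. The chunking function supports an optional overlap between consecutive chunks, but its default is no overlap, and the RSE pipeline below uses non-overlapping chunks so that neighbouring chunks can be stitched back into continuous segments.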
def chunk_text(text, chunk_size=800, overlap=0):
    """Split text into fixed-size character chunks.

    Consecutive chunks start ``chunk_size - overlap`` characters apart, so
    each chunk shares its last ``overlap`` characters with the next one.
    The final chunk may be shorter than ``chunk_size``.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by neighbouring chunks; must
            satisfy ``0 <= overlap < chunk_size``.

    Returns:
        list[str]: The chunks in document order (empty for empty text).

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` is out of range.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # overlap >= chunk_size would make the step zero or negative, which used
    # to surface as a cryptic range() error or as a silently empty result;
    # a negative overlap would silently skip characters between chunks.
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
        )

    step = chunk_size - overlap
    # Every start index is < len(text), so no slice can be empty.
    return [text[start:start + chunk_size] for start in range(0, len(text), step)]
Setting Up the OpenAI API Client
We initialize the OpenAI client to generate embeddings and responses.
# Module-level API client used for both embedding and chat-completion calls.
# It talks to the Nebius endpoint and takes its key from the OPENAI_API_KEY
# environment variable (None is passed through when the variable is unset).
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url="https://api.studio.nebius.com/v1/",
)
Building a Simple Vector Store
Let's implement a simple in-memory vector store that ranks stored chunks by cosine similarity to a query embedding.
class SimpleVectorStore:
    """Minimal in-memory vector store with cosine-similarity search.

    Documents, vectors and metadata are kept in three parallel lists, so
    position ``i`` in each list describes the same entry.
    """

    def __init__(self, dimension=1536):
        """Create an empty store.

        Args:
            dimension: Nominal embedding size. It is only recorded on the
                instance; nothing in this class checks vectors against it.
        """
        self.dimension = dimension
        self.vectors = []
        self.documents = []
        self.metadata = []

    def add_documents(self, documents, vectors=None, metadata=None):
        """Append documents with their optional vectors and metadata.

        Args:
            documents: Iterable of documents to store.
            vectors: One vector per document, or None to store the documents
                without vectors (such entries are never returned by search).
            metadata: One metadata dict per document, or None to give every
                document its own empty dict.

        Raises:
            ValueError: If the three inputs do not have the same length.
        """
        documents = list(documents)
        vectors = [None] * len(documents) if vectors is None else list(vectors)
        metadata = [{} for _ in documents] if metadata is None else list(metadata)

        # zip() would silently drop the surplus entries of the longer lists,
        # losing documents without any signal; fail loudly instead.
        if not len(documents) == len(vectors) == len(metadata):
            raise ValueError(
                "documents, vectors and metadata must have the same length "
                f"(got {len(documents)}, {len(vectors)}, {len(metadata)})"
            )

        self.documents.extend(documents)
        self.vectors.extend(vectors)
        self.metadata.extend(metadata)

    def search(self, query_vector, top_k=5):
        """Return the stored entries most similar to a query vector.

        Args:
            query_vector: Query embedding (any array-like accepted by NumPy).
            top_k: Maximum number of results to return.

        Returns:
            list[dict]: Up to ``top_k`` dicts with keys ``"document"``,
            ``"score"`` (cosine similarity as a float) and ``"metadata"``,
            ordered from most to least similar. Entries stored without a
            vector are skipped. Empty when the store is empty.
        """
        if not self.vectors or not self.documents:
            return []

        query_array = np.array(query_vector)
        # The query norm does not depend on the stored vector: compute once.
        query_norm = np.linalg.norm(query_array)

        similarities = []
        for i, vector in enumerate(self.vectors):
            if vector is None:
                continue
            denominator = query_norm * np.linalg.norm(vector)
            if denominator:
                similarity = np.dot(query_array, vector) / denominator
            else:
                # A zero-length vector has no direction; score it 0.0 rather
                # than dividing by zero and injecting NaN into the ranking.
                similarity = 0.0
            similarities.append((i, similarity))

        similarities.sort(key=lambda pair: pair[1], reverse=True)

        return [
            {
                "document": self.documents[i],
                "score": float(score),
                "metadata": self.metadata[i],
            }
            for i, score in similarities[:top_k]
        ]
Creating Embeddings for Text Chunks
Embeddings transform text into numerical vectors, which allow for efficient similarity search.
def create_embeddings(texts, model="BAAI/bge-en-icl", batch_size=100):
    """Embed a list of texts through the embeddings API, in batches.

    Args:
        texts: List of strings to embed.
        model: Name of the embedding model to request.
        batch_size: Maximum number of texts sent per API request; must be
            positive. Defaults to 100, the value that used to be hard-coded.

    Returns:
        list: One embedding per item returned by the API, concatenated in
        batch order. Empty when ``texts`` is empty (no request is made).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if not texts:
        return []

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            input=texts[start:start + batch_size],
            model=model,
        )
        all_embeddings.extend(item.embedding for item in response.data)
    return all_embeddings
Processing Documents with RSE
Now let's implement the core RSE functionality.
def process_document(pdf_path, chunk_size=800):
    """Turn a PDF into chunks and a vector store populated with them.

    The text is extracted, split into non-overlapping chunks of
    ``chunk_size`` characters, embedded, and stored together with metadata
    recording each chunk's index and the source path.

    Args:
        pdf_path: Path of the PDF to process.
        chunk_size: Number of characters per chunk.

    Returns:
        tuple: ``(chunks, vector_store, doc_info)`` where ``doc_info`` is a
        dict holding the chunk list under ``"chunks"`` and the path under
        ``"source"``.
    """
    print("Extracting text from document...")
    document_text = extract_text_from_pdf(pdf_path)

    print("Chunking text into non-overlapping segments...")
    chunks = chunk_text(document_text, chunk_size=chunk_size, overlap=0)
    print(f"Created {len(chunks)} chunks")

    print("Generating embeddings for chunks...")
    embeddings = create_embeddings(chunks)

    # chunk_index is what later lets a search hit be mapped back to its
    # position in the document.
    per_chunk_metadata = [
        {"chunk_index": position, "source": pdf_path}
        for position, _ in enumerate(chunks)
    ]
    store = SimpleVectorStore()
    store.add_documents(chunks, embeddings, per_chunk_metadata)

    return chunks, store, {"chunks": chunks, "source": pdf_path}
RSE Core Algorithm: Computing Chunk Values and Finding Best Segments
Now that we have the necessary functions to process a document and generate embeddings for its chunks, we can implement the core algorithm for RSE.
def calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty=0.2):
    """Score every chunk as its query similarity minus a fixed penalty.

    Args:
        query: The user query to embed.
        chunks: The document chunks; only their count is used here.
        vector_store: Store whose ``search`` results carry a
            ``"chunk_index"`` entry in their metadata.
        irrelevant_chunk_penalty: Amount subtracted from every similarity
            score, so weakly related chunks end up with a negative value.

    Returns:
        list[float]: One value per chunk, in chunk order. A chunk that is
        missing from the search results is treated as having similarity 0.0.
    """
    query_vector = create_embeddings([query])[0]
    chunk_count = len(chunks)

    # Ask for as many hits as there are chunks so every chunk gets a score.
    hits = vector_store.search(query_vector, top_k=chunk_count)

    score_by_index = {}
    for hit in hits:
        score_by_index[hit["metadata"]["chunk_index"]] = hit["score"]

    return [
        score_by_index.get(index, 0.0) - irrelevant_chunk_penalty
        for index in range(chunk_count)
    ]
def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """Greedily pick disjoint runs of consecutive chunks with the highest value.

    On each round the contiguous window with the largest sum of chunk values
    is selected among all windows that do not touch an already selected
    chunk, are at most ``max_segment_length`` chunks long and fit in the
    remaining chunk budget. Selection stops when no window scores strictly
    above ``min_segment_value`` or the budget is used up.

    Args:
        chunk_values: Per-chunk values, in document order.
        max_segment_length: Maximum number of chunks in one segment.
        total_max_length: Maximum number of chunks across all segments.
        min_segment_value: A segment's summed value must exceed this.

    Returns:
        tuple: ``(best_segments, segment_scores)``. ``best_segments`` is a
        list of half-open ``(start, end)`` chunk ranges sorted by start;
        ``segment_scores[i]`` is the summed value of ``best_segments[i]``.
    """
    print("Finding optimal continuous text segments...")

    num_chunks = len(chunk_values)
    # covered[i] is True once chunk i belongs to a selected segment. Checking
    # every chunk of a candidate (rather than only its two endpoints) also
    # rejects candidates that fully enclose an earlier, shorter segment.
    covered = [False] * num_chunks
    selected = []  # (segment, score) pairs in selection order
    total_included_chunks = 0

    while total_included_chunks < total_max_length:
        # Never pick a segment that would push the total past the budget.
        remaining = total_max_length - total_included_chunks
        best_score = min_segment_value
        best_segment = None

        for start in range(num_chunks):
            if covered[start]:
                continue
            longest = min(max_segment_length, remaining, num_chunks - start)
            running_total = 0
            for end in range(start + 1, start + longest + 1):
                if covered[end - 1]:
                    # Every longer window from this start would contain the
                    # same covered chunk, so stop extending.
                    break
                # Running sum instead of re-summing the whole slice.
                running_total += chunk_values[end - 1]
                if running_total > best_score:
                    best_score = running_total
                    best_segment = (start, end)

        if best_segment is None:
            break

        selected.append((best_segment, best_score))
        for index in range(best_segment[0], best_segment[1]):
            covered[index] = True
        total_included_chunks += best_segment[1] - best_segment[0]
        print(f"Found segment {best_segment} with score {best_score:.4f}")

    # Sort segments and scores together so they stay aligned by position.
    selected.sort(key=lambda pair: pair[0][0])
    best_segments = [segment for segment, _ in selected]
    segment_scores = [score for _, score in selected]
    return best_segments, segment_scores
Reconstructing and Using Segments for RAG
def reconstruct_segments(chunks, best_segments):
    """Build the text of each selected segment from its chunks.

    Args:
        chunks: All document chunks, in order.
        best_segments: Half-open ``(start, end)`` chunk index ranges.

    Returns:
        list[dict]: One dict per segment with the chunks of the range joined
        by single spaces under ``"text"`` and the range itself under
        ``"segment_range"``.
    """
    return [
        {
            "text": " ".join(chunks[first:stop]),
            "segment_range": (first, stop),
        }
        for first, stop in best_segments
    ]
def format_segments_for_context(segments):
    """Render segments as one context string for the language model.

    Each segment contributes a header naming its 1-based position and its
    inclusive chunk range, then its text, then a line of 80 dashes. All
    pieces are separated by blank lines.

    Args:
        segments: Dicts with ``"text"`` and ``"segment_range"`` entries.

    Returns:
        str: The formatted context (empty string for no segments).
    """
    pieces = []
    for number, segment in enumerate(segments, start=1):
        bounds = segment['segment_range']
        # The stored end index is exclusive; the header shows the last
        # chunk actually included.
        pieces.extend([
            f"SEGMENT {number} (Chunks {bounds[0]}-{bounds[1] - 1}):",
            segment['text'],
            "-" * 80,
        ])
    return "\n\n".join(pieces)
Generating Responses with RSE Context
def generate_response(query, context, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Ask the chat model to answer a query from the supplied context.

    Args:
        query: The user's question.
        context: Retrieved text to ground the answer in.
        model: Name of the chat model to request.

    Returns:
        The message content of the first choice in the API response.
    """
    print("Generating response using relevant segments as context...")

    instructions = """You are a helpful assistant that answers questions based on the provided context.
The context consists of document segments that have been retrieved as relevant to the user's query.
Use the information from these segments to provide a comprehensive and accurate answer.
If the context doesn't contain relevant information to answer the question, say so clearly."""

    question_block = f"""
Context:
{context}
Question: {query}
Please provide a helpful answer based on the context provided.
"""

    conversation = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": question_block},
    ]
    # temperature=0 is requested for every answer.
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=0,
    )
    return completion.choices[0].message.content
Complete RSE Pipeline Function
def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2):
    """Run the full Relevant Segment Extraction pipeline for one query.

    The document is processed into chunks, each chunk is valued against the
    query, the best contiguous segments are selected and rebuilt into text,
    and the model answers the query from that text.

    Args:
        pdf_path: Path of the PDF to query.
        query: The user's question.
        chunk_size: Number of characters per chunk.
        irrelevant_chunk_penalty: Penalty subtracted from each chunk's
            similarity score.

    Returns:
        dict: ``"query"``, the reconstructed ``"segments"`` and the model's
        ``"response"``.
    """
    print("\n=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===")
    print(f"Query: {query}")

    chunks, vector_store, _doc_info = process_document(pdf_path, chunk_size)

    print("\nCalculating relevance scores and chunk values...")
    values = calculate_chunk_values(
        query, chunks, vector_store, irrelevant_chunk_penalty
    )

    segment_ranges, _segment_scores = find_best_segments(
        values,
        max_segment_length=20,
        total_max_length=30,
        min_segment_value=0.2,
    )

    print("\nReconstructing text segments from chunks...")
    segments = reconstruct_segments(chunks, segment_ranges)
    answer = generate_response(query, format_segments_for_context(segments))

    print("\n=== FINAL RESPONSE ===")
    print(answer)

    return {"query": query, "segments": segments, "response": answer}
Comparing with Standard Retrieval
Let's implement a standard retrieval approach to compare with RSE:
def standard_top_k_retrieval(pdf_path, query, k=10, chunk_size=800):
    """Answer a query with plain top-k chunk retrieval (the baseline).

    Args:
        pdf_path: Path of the PDF to query.
        query: The user's question.
        k: Number of most similar chunks to retrieve.
        chunk_size: Number of characters per chunk.

    Returns:
        dict: ``"query"``, the retrieved ``"chunks"`` in ranking order and
        the model's ``"response"``.
    """
    print("\n=== STARTING STANDARD TOP-K RETRIEVAL ===")
    print(f"Query: {query}")

    _chunks, vector_store, _doc_info = process_document(pdf_path, chunk_size)

    print("Creating query embedding and retrieving chunks...")
    query_vector = create_embeddings([query])[0]
    hits = vector_store.search(query_vector, top_k=k)
    retrieved_chunks = [hit["document"] for hit in hits]

    # Label chunks by rank (1-based) so the model can tell them apart.
    labelled = []
    for rank, chunk in enumerate(retrieved_chunks, start=1):
        labelled.append(f"CHUNK {rank}:\n{chunk}")
    answer = generate_response(query, "\n\n".join(labelled))

    print("\n=== FINAL RESPONSE ===")
    print(answer)

    return {"query": query, "chunks": retrieved_chunks, "response": answer}
Evaluation of RSE
def evaluate_methods(pdf_path, query, reference_answer=None):
    """Run RSE and standard top-k retrieval on a query and compare them.

    Both pipelines are always executed. When a reference answer is given,
    a chat model is additionally asked to judge the two responses against
    it, and its verdict is printed.

    Args:
        pdf_path: Path of the PDF to query.
        query: The user's question.
        reference_answer: Optional ideal answer used as ground truth for
            the comparison. The comparison is skipped when it is falsy.

    Returns:
        dict: ``"rse_result"`` and ``"standard_result"`` (the dicts returned
        by the two pipelines) plus ``"evaluation"``, the judge's text, or
        None when no comparison was run.
    """
    print("\n========= EVALUATION =========\n")

    rse_result = rag_with_rse(pdf_path, query)
    standard_result = standard_top_k_retrieval(pdf_path, query)

    # The verdict used to be printed and then thrown away; keep it so that
    # callers can store or post-process it.
    evaluation_text = None

    if reference_answer:
        print("\n=== COMPARING RESULTS ===")
        evaluation_prompt = f"""
Query: {query}
Reference Answer:
{reference_answer}
Response from Standard Retrieval:
{standard_result["response"]}
Response from Relevant Segment Extraction:
{rse_result["response"]}
Compare these two responses against the reference answer. Which one is:
1. More accurate and comprehensive
2. Better at addressing the user's query
3. Less likely to include irrelevant information
Explain your reasoning for each point.
"""
        print("Evaluating responses against reference answer...")
        evaluation = client.chat.completions.create(
            model="meta-llama/Llama-3.2-3B-Instruct",
            messages=[
                {"role": "system", "content": "You are an objective evaluator of RAG system responses."},
                {"role": "user", "content": evaluation_prompt}
            ],
            # Same setting as the answer generation, so the judge does not
            # add sampling randomness to the comparison.
            temperature=0
        )
        evaluation_text = evaluation.choices[0].message.content

        print("\n=== EVALUATION RESULTS ===")
        print(evaluation_text)

    return {
        "rse_result": rse_result,
        "standard_result": standard_result,
        "evaluation": evaluation_text
    }
# Load the validation set and evaluate both retrieval methods on its first
# question. The encoding is given explicitly so that the JSON file is decoded
# the same way on every platform instead of using the locale default.
with open('data/val.json', encoding='utf-8') as f:
    data = json.load(f)

# Each entry is read for its 'question' and 'ideal_answer' fields.
query = data[0]['question']
reference_answer = data[0]['ideal_answer']
pdf_path = "data/AI_Information.pdf"

results = evaluate_methods(pdf_path, query, reference_answer)