Query Transformations for Enhanced RAG Systems
This notebook implements three query transformation techniques to enhance retrieval performance in RAG systems without relying on specialized libraries like LangChain. By modifying user queries, we can significantly improve the relevance and comprehensiveness of retrieved information.
Key Transformation Techniques
- Query Rewriting: Makes queries more specific and detailed for better search precision.
- Step-back Prompting: Generates broader queries to retrieve useful contextual information.
- Sub-query Decomposition: Breaks complex queries into simpler components for comprehensive retrieval.
Setting Up the Environment
We begin by importing necessary libraries.
import os
import json

import fitz  # PyMuPDF, used to extract text from PDF files
import numpy as np
from openai import OpenAI
Setting Up the OpenAI API Client
We initialize the OpenAI client to generate embeddings and responses.
client = OpenAI(
    base_url="https://api.studio.nebius.com/v1/",  # Nebius AI Studio endpoint
    api_key=os.getenv("OPENAI_API_KEY")  # expects the API key in this environment variable
)
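As a quick, optional sanity check, you can confirm the client is configured correctly; this assumes the endpoint exposes the standard OpenAI-compatible model listing, which may not hold for every provider.

# Optional sanity check (assumes an OpenAI-compatible /models route)
try:
    models = client.models.list()
    print("Connected; first available model:", models.data[0].id)
except Exception as e:
    print("Client setup issue:", e)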
Implementing Query Transformation Techniques
1. Query Rewriting
This technique makes queries more specific and detailed to improve precision in retrieval.
def rewrite_query(original_query, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Rewrite a query to make it more specific and detailed."""
    system_prompt = "You are an AI assistant specialized in improving search queries. Your task is to rewrite user queries to be more specific, detailed, and likely to retrieve relevant information."
    user_prompt = f"""
    Rewrite the following query to make it more specific and detailed. Include relevant terms and concepts that might help in retrieving accurate information.

    Original query: {original_query}

    Rewritten query:
    """

    # Temperature 0 keeps the rewrite deterministic
    response = client.chat.completions.create(
        model=model,
        temperature=0.0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )
    return response.choices[0].message.content.strip()
2. Step-back Prompting
This technique generates broader queries to retrieve contextual background information.
def generate_step_back_query(original_query, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Generate a broader 'step-back' version of a query to retrieve background context."""
    system_prompt = "You are an AI assistant specialized in search strategies. Your task is to generate broader, more general versions of specific queries to retrieve relevant background information."
    user_prompt = f"""
    Generate a broader, more general version of the following query that could help retrieve useful background information.

    Original query: {original_query}

    Step-back query:
    """

    response = client.chat.completions.create(
        model=model,
        temperature=0.1,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )
    return response.choices[0].message.content.strip()
3. Sub-query Decomposition
This technique breaks down complex queries into simpler components for comprehensive retrieval.
def decompose_query(original_query, num_subqueries=4, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Decompose a complex query into simpler sub-queries."""
    system_prompt = "You are an AI assistant specialized in breaking down complex questions. Your task is to decompose complex queries into simpler sub-questions that, when answered together, address the original query."
    user_prompt = f"""
    Break down the following complex query into {num_subqueries} simpler sub-queries. Each sub-query should focus on a different aspect of the original question.

    Original query: {original_query}

    Generate {num_subqueries} sub-queries, one per line, in this format:
    1. [First sub-query]
    2. [Second sub-query]
    And so on...
    """

    response = client.chat.completions.create(
        model=model,
        temperature=0.2,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )

    content = response.choices[0].message.content.strip()

    # Keep only lines that start with a numbered prefix like "1." and strip the prefix
    lines = content.split("\n")
    sub_queries = []
    for line in lines:
        if line.strip() and any(line.strip().startswith(f"{i}.") for i in range(1, 10)):
            query = line.strip()
            query = query[query.find(".") + 1:].strip()
            sub_queries.append(query)
    return sub_queries
Demonstrating Query Transformation Techniques
Let's apply these techniques to an example query.
original_query = "What are the impacts of AI on job automation and employment?"
print("Original Query:", original_query)
rewritten_query = rewrite_query(original_query)
print("\n1. Rewritten Query:")
print(rewritten_query)
step_back_query = generate_step_back_query(original_query)
print("\n2. Step-back Query:")
print(step_back_query)
sub_queries = decompose_query(original_query, num_subqueries=4)
print("\n3. Sub-queries:")
for i, query in enumerate(sub_queries, 1):
    print(f"   {i}. {query}")
Building a Simple Vector Store
To demonstrate how query transformations integrate with retrieval, let's implement a simple vector store.
class SimpleVectorStore:
    """A minimal in-memory vector store backed by NumPy and cosine similarity."""

    def __init__(self):
        self.vectors = []
        self.texts = []
        self.metadata = []

    def add_item(self, text, embedding, metadata=None):
        """Store a text chunk together with its embedding and optional metadata."""
        self.vectors.append(np.array(embedding))
        self.texts.append(text)
        self.metadata.append(metadata or {})

    def similarity_search(self, query_embedding, k=5):
        """Return the k stored items most similar to the query embedding."""
        if not self.vectors:
            return []

        query_vector = np.array(query_embedding)

        # Cosine similarity between the query and every stored vector
        similarities = []
        for i, vector in enumerate(self.vectors):
            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
            similarities.append((i, similarity))

        # Sort by similarity, highest first
        similarities.sort(key=lambda x: x[1], reverse=True)

        results = []
        for i in range(min(k, len(similarities))):
            idx, score = similarities[i]
            results.append({
                "text": self.texts[idx],
                "metadata": self.metadata[idx],
                "similarity": score
            })
        return results
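To see the store in action before wiring in real embeddings, here is a minimal sketch with hypothetical 3-dimensional vectors (real embedding vectors have hundreds of dimensions, and the texts here are made up for illustration):

# Minimal sketch with hypothetical toy vectors
toy_store = SimpleVectorStore()
toy_store.add_item("Dogs are loyal pets.", [0.9, 0.1, 0.0], {"id": 0})
toy_store.add_item("Cats are independent.", [0.1, 0.9, 0.0], {"id": 1})

# A query vector close to the first item should rank it highest
for hit in toy_store.similarity_search([0.8, 0.2, 0.0], k=2):
    print(round(hit["similarity"], 3), hit["text"])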
Creating Embeddings
This function creates embeddings for either a single string or a list of strings, returning a single vector or a list of vectors accordingly.
def create_embeddings(text, model="BAAI/bge-en-icl"):
    """Create embeddings for a string or a list of strings."""
    # The API expects a list, so wrap single strings
    input_text = text if isinstance(text, list) else [text]

    response = client.embeddings.create(
        model=model,
        input=input_text
    )

    # Return a single embedding for string input, a list of embeddings otherwise
    if isinstance(text, str):
        return response.data[0].embedding
    return [item.embedding for item in response.data]
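Embedding APIs typically cap the number of inputs per request. If a document produces many chunks, a small wrapper like the sketch below keeps requests within bounds; note that the batch size of 64 is an assumption for illustration, not a documented limit of this endpoint.

def create_embeddings_batched(texts, batch_size=64, model="BAAI/bge-en-icl"):
    """Sketch: embed a long list of texts in batches (batch_size is assumed, not documented)."""
    embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        embeddings.extend(create_embeddings(batch, model=model))
    return embeddings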
Implementing RAG with Query Transformations
We first need two helpers: one to extract text from a PDF and one to split that text into overlapping chunks.
def extract_text_from_pdf(pdf_path):
    """Extract all text from a PDF file."""
    mypdf = fitz.open(pdf_path)
    all_text = ""

    for page_num in range(mypdf.page_count):
        page = mypdf[page_num]
        text = page.get_text("text")
        all_text += text

    return all_text
def chunk_text(text, n=1000, overlap=200):
    """Split text into chunks of n characters with the given overlap."""
    chunks = []
    # Step by (n - overlap) so consecutive chunks share `overlap` characters
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])
    return chunks
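As a quick sanity check on the stride arithmetic, the sketch below uses a hypothetical 2,500-character string: chunks start at offsets 0, 800, 1600, and 2400, so each chunk shares 200 characters with the previous one.

# Hypothetical example: verify the chunk boundaries produced by chunk_text
sample = "x" * 2500
sample_chunks = chunk_text(sample, n=1000, overlap=200)

print(len(sample_chunks))                # 4
print([len(c) for c in sample_chunks])   # [1000, 1000, 900, 100]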
def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):
    """Extract, chunk, embed, and index a PDF document into a vector store."""
    print("Extracting text from PDF...")
    extracted_text = extract_text_from_pdf(pdf_path)

    print("Chunking text...")
    chunks = chunk_text(extracted_text, chunk_size, chunk_overlap)
    print(f"Created {len(chunks)} text chunks")

    print("Creating embeddings for chunks...")
    chunk_embeddings = create_embeddings(chunks)

    store = SimpleVectorStore()
    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        store.add_item(
            text=chunk,
            embedding=embedding,
            metadata={"index": i, "source": pdf_path}
        )

    print(f"Added {len(chunks)} chunks to the vector store")
    return store
Searching with Transformed Queries
This function applies the selected transformation to the query before searching the vector store.
def transformed_search(query, vector_store, transformation_type, top_k=3):
    """Search the vector store using a transformed version of the query."""
    print(f"Transformation type: {transformation_type}")
    print(f"Original query: {query}")

    results = []

    if transformation_type == "rewrite":
        # Query rewriting: search with a more specific version of the query
        transformed_query = rewrite_query(query)
        print(f"Rewritten query: {transformed_query}")
        query_embedding = create_embeddings(transformed_query)
        results = vector_store.similarity_search(query_embedding, k=top_k)

    elif transformation_type == "step_back":
        # Step-back prompting: search with a broader version of the query
        transformed_query = generate_step_back_query(query)
        print(f"Step-back query: {transformed_query}")
        query_embedding = create_embeddings(transformed_query)
        results = vector_store.similarity_search(query_embedding, k=top_k)

    elif transformation_type == "decompose":
        # Sub-query decomposition: search with each sub-query and merge the results
        sub_queries = decompose_query(query)
        print("Decomposed into sub-queries:")
        for i, sub_q in enumerate(sub_queries, 1):
            print(f"{i}. {sub_q}")

        sub_query_embeddings = create_embeddings(sub_queries)

        all_results = []
        for i, embedding in enumerate(sub_query_embeddings):
            sub_results = vector_store.similarity_search(embedding, k=2)
            all_results.extend(sub_results)

        # Deduplicate by text, keeping the highest similarity score for each chunk
        seen_texts = {}
        for result in all_results:
            text = result["text"]
            if text not in seen_texts or result["similarity"] > seen_texts[text]["similarity"]:
                seen_texts[text] = result

        results = sorted(seen_texts.values(), key=lambda x: x["similarity"], reverse=True)[:top_k]

    else:
        # No transformation: search with the original query
        query_embedding = create_embeddings(query)
        results = vector_store.similarity_search(query_embedding, k=top_k)

    return results
Generating a Response from the Retrieved Context
This function generates the final answer, grounded only in the passages retrieved for the (possibly transformed) query.
def generate_response(query, context, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Generate an answer grounded only in the provided context."""
    system_prompt = "You are a helpful AI assistant. Answer the user's question based only on the provided context. If you cannot find the answer in the context, state that you don't have enough information."
    user_prompt = f"""
    Context:
    {context}

    Question: {query}

    Please provide a comprehensive answer based only on the context above.
    """

    response = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )
    return response.choices[0].message.content.strip()
Running the Complete RAG Pipeline with Query Transformations
def rag_with_query_transformation(pdf_path, query, transformation_type=None):
    """Run the full RAG pipeline, optionally transforming the query first."""
    # Build the vector store from the document
    vector_store = process_document(pdf_path)

    # Retrieve with or without a query transformation
    if transformation_type:
        results = transformed_search(query, vector_store, transformation_type)
    else:
        query_embedding = create_embeddings(query)
        results = vector_store.similarity_search(query_embedding, k=3)

    # Combine the retrieved passages into a single context block
    context = "\n\n".join([f"PASSAGE {i+1}:\n{result['text']}" for i, result in enumerate(results)])

    # Generate the final answer from the context
    response = generate_response(query, context)

    return {
        "original_query": query,
        "transformation_type": transformation_type,
        "context": context,
        "response": response
    }
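The pipeline can also be exercised with a single transformation before running the full comparison below; this sketch reuses the document path assumed throughout the notebook, and the query is a made-up example.

# Hypothetical one-off run with the step-back transformation
result = rag_with_query_transformation(
    "data/AI_Information.pdf",
    "How does AI affect employment?",
    transformation_type="step_back"
)
print(result["response"])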
Evaluating Transformation Techniques
To judge which transformation helps most, we use an LLM to score and rank each technique's response against a reference answer.
def compare_responses(results, reference_answer, model="meta-llama/Llama-3.2-3B-Instruct"):
    """Use an LLM judge to compare each technique's response against a reference answer."""
    system_prompt = """You are an expert evaluator of RAG systems.
    Your task is to compare different responses generated using various query transformation techniques
    and determine which technique produced the best response compared to the reference answer."""

    # Assemble the reference answer and every technique's response into one prompt
    comparison_text = f"""Reference Answer: {reference_answer}\n\n"""
    for technique, result in results.items():
        comparison_text += f"{technique.capitalize()} Query Response:\n{result['response']}\n\n"

    user_prompt = f"""
    {comparison_text}

    Compare the responses generated by different query transformation techniques to the reference answer.

    For each technique (original, rewrite, step_back, decompose):
    1. Score the response from 1-10 based on accuracy, completeness, and relevance
    2. Identify strengths and weaknesses

    Then rank the techniques from best to worst and explain which technique performed best overall and why.
    """

    response = client.chat.completions.create(
        model=model,
        temperature=0,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
    )

    print("\n===== EVALUATION RESULTS =====")
    print(response.choices[0].message.content)
    print("=============================")
def evaluate_transformations(pdf_path, query, reference_answer=None):
    """Run RAG with each transformation type and optionally compare the responses."""
    transformation_types = [None, "rewrite", "step_back", "decompose"]
    results = {}

    for transformation_type in transformation_types:
        type_name = transformation_type if transformation_type else "original"
        print(f"\n===== Running RAG with {type_name} query =====")

        result = rag_with_query_transformation(pdf_path, query, transformation_type)
        results[type_name] = result

        print(f"Response with {type_name} query:")
        print(result["response"])
        print("=" * 50)

    # If a reference answer is available, run the LLM-based comparison
    if reference_answer:
        compare_responses(results, reference_answer)

    return results
Running the Evaluation
Finally, we load a validation query and its reference answer, then run the full evaluation across all transformation types.
# Load the validation data
with open('data/val.json') as f:
    data = json.load(f)

# Use the first validation example
query = data[0]['question']
reference_answer = data[0]['ideal_answer']

pdf_path = "data/AI_Information.pdf"

# Run the evaluation across all transformation types
evaluation_results = evaluate_transformations(pdf_path, query, reference_answer)
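If you want to inspect the outputs later, the results dictionary serializes directly to JSON; the output path here is just an example.

# Optional: persist the evaluation results for later inspection (hypothetical output path)
with open("evaluation_results.json", "w") as f:
    json.dump(evaluation_results, f, indent=2)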