Oreoluwa
Corrective RAG (CRAG) Implementation
June 4, 2025
25 min read

Corrective RAG (CRAG) Implementation

RAG
NLP
AI
CRAG
Web Search

Corrective RAG (CRAG) Implementation

In this notebook, I implement Corrective RAG (CRAG), an advanced approach that dynamically evaluates retrieved information and corrects the retrieval process when necessary, using web search as a fallback.

CRAG improves on traditional RAG by:

  • Evaluating retrieved content before using it
  • Dynamically switching between knowledge sources based on relevance
  • Correcting the retrieval with web search when local knowledge is insufficient
  • Combining information from multiple sources when appropriate

Setting Up the Environment

We begin by importing the necessary libraries.

import os
import numpy as np
import json
import fitz # PyMuPDF
from openai import OpenAI
import requests
from typing import List, Dict, Tuple, Any
import re
from urllib.parse import quote_plus
import time

Setting Up the OpenAI API Client

We initialize the OpenAI Python client, pointed at Nebius Studio's OpenAI-compatible endpoint via `base_url`, to generate embeddings and chat responses.

# Shared OpenAI-SDK client. It talks to the OpenAI-compatible endpoint given in
# base_url and authenticates with the OPENAI_API_KEY environment variable.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url="https://api.studio.nebius.com/v1/",
)

Document Processing Functions

def extract_text_from_pdf(pdf_path):
    """Extract the plain text of every page of a PDF.

    Args:
        pdf_path: Path to the PDF file, passed to ``fitz.open``.

    Returns:
        The ``page.get_text()`` output of all pages concatenated in page
        order, with no separator added between pages.
    """
    print(f"Extracting text from {pdf_path}...")
    # Use the document as a context manager so the underlying file is closed
    # even if extraction raises part-way through (the old version leaked it).
    with fitz.open(pdf_path) as pdf:
        return "".join(page.get_text() for page in pdf)

def chunk_text(text, chunk_size=1000, overlap=200):
    """Split ``text`` into fixed-size, overlapping character windows.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be > 0.
        overlap: Characters shared by consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size`` so each window advances.

    Returns:
        A list of ``{"text": ..., "metadata": {...}}`` dicts. ``metadata``
        holds ``start_pos``/``end_pos`` character offsets into ``text`` and
        ``source_type`` set to ``"document"``. Empty ``text`` yields ``[]``.

    Raises:
        ValueError: If ``chunk_size`` is not positive, or ``overlap`` is
            negative or not smaller than ``chunk_size``. Such values would
            otherwise make ``range()`` fail with an opaque error, silently
            return no chunks, or skip characters between chunks.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), got {overlap}"
        )
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        # Named ``piece`` rather than ``chunk_text`` so it doesn't shadow this
        # function. It is never empty because ``start < len(text)``.
        piece = text[start:start + chunk_size]
        chunks.append({
            "text": piece,
            "metadata": {
                "start_pos": start,
                "end_pos": start + len(piece),
                "source_type": "document",
            },
        })
    print(f"Created {len(chunks)} text chunks")
    return chunks

Simple Vector Store Implementation

class SimpleVectorStore:
    """In-memory vector store with brute-force cosine-similarity search.

    Embeddings, texts and metadata live in three parallel lists; position
    ``i`` in each list describes the same stored item.
    """

    def __init__(self):
        self.vectors = []   # numpy array per stored item
        self.texts = []     # raw text per stored item
        self.metadata = []  # metadata dict per stored item

    def add_item(self, text, embedding, metadata=None):
        """Store one text with its embedding and an optional metadata dict."""
        self.vectors.append(np.array(embedding))
        self.texts.append(text)
        self.metadata.append(metadata or {})

    def add_items(self, items, embeddings):
        """Store several items at once.

        Args:
            items: Dicts with a ``"text"`` key and an optional ``"metadata"`` key.
            embeddings: One embedding per item, in the same order as ``items``.

        Raises:
            ValueError: If ``items`` and ``embeddings`` differ in length.
                ``zip`` would otherwise silently drop the extras.
        """
        items = list(items)
        embeddings = list(embeddings)
        if len(items) != len(embeddings):
            raise ValueError(
                f"got {len(items)} items but {len(embeddings)} embeddings"
            )
        for item, embedding in zip(items, embeddings):
            self.add_item(
                text=item["text"],
                embedding=embedding,
                metadata=item.get("metadata", {}),
            )

    def similarity_search(self, query_embedding, k=5):
        """Return the ``k`` stored items most cosine-similar to the query.

        Args:
            query_embedding: Vector with the same dimensionality as the stored ones.
            k: Maximum number of results; values <= 0 yield ``[]``.

        Returns:
            Up to ``k`` dicts with ``text``, ``metadata`` and ``similarity``
            (a Python float), sorted by descending similarity. Ties keep
            insertion order. A zero-length query or stored vector gets
            similarity 0.0 instead of producing NaN from a division by zero.
        """
        if not self.vectors:
            return []
        query_vector = np.array(query_embedding)
        query_norm = np.linalg.norm(query_vector)  # same for every stored item
        similarities = []
        for idx, vector in enumerate(self.vectors):
            denom = query_norm * np.linalg.norm(vector)
            score = np.dot(query_vector, vector) / denom if denom else 0.0
            similarities.append((idx, score))
        # list.sort is stable even with reverse=True, so ties keep insertion order.
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return [
            {
                "text": self.texts[idx],
                "metadata": self.metadata[idx],
                "similarity": float(score),
            }
            for idx, score in similarities[:max(k, 0)]
        ]

Creating Embeddings

def create_embeddings(texts, model="text-embedding-3-small", batch_size=100):
    """Embed one string or a sequence of strings with the embeddings API.

    Args:
        texts: A single string, or any sequence of strings (list, tuple, ...).
        model: Embedding model name passed to ``client.embeddings.create``.
        batch_size: Maximum number of texts sent per API request; must be > 0.

    Returns:
        For a single string, its embedding. Otherwise a list of embeddings in
        the same order as ``texts`` (``[]`` for an empty sequence).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    single = isinstance(texts, str)
    # Accept any sequence of strings. Previously a tuple was wrapped as one
    # "text" and sent to the API as a nested input.
    input_texts = [texts] if single else list(texts)
    all_embeddings = []
    for start in range(0, len(input_texts), batch_size):
        batch = input_texts[start:start + batch_size]
        response = client.embeddings.create(model=model, input=batch)
        all_embeddings.extend(item.embedding for item in response.data)
    return all_embeddings[0] if single else all_embeddings

Document Processing Pipeline

def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):
    """Build a searchable vector store from a PDF.

    Pipeline: extract the PDF text, split it into overlapping chunks, embed
    every chunk, then load chunks plus embeddings into a ``SimpleVectorStore``.

    Args:
        pdf_path: Path to the PDF to index.
        chunk_size: Characters per chunk.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The populated ``SimpleVectorStore``.
    """
    chunks = chunk_text(extract_text_from_pdf(pdf_path), chunk_size, chunk_overlap)
    print("Creating embeddings for chunks...")
    embeddings = create_embeddings([piece["text"] for piece in chunks])
    store = SimpleVectorStore()
    store.add_items(chunks, embeddings)
    print(f"Vector store created with {len(chunks)} chunks")
    return store

Relevance Evaluation Function

def evaluate_document_relevance(query, document):
    """Ask the LLM how relevant ``document`` is to ``query``.

    Args:
        query: The user question.
        document: Candidate chunk text.

    Returns:
        A float in ``[0.0, 1.0]``. The neutral value 0.5 is returned if the
        API call fails or the reply contains no number.
    """
    system_prompt = """
    You are an expert at evaluating document relevance. 
    Rate how relevant the given document is to the query on a scale from 0 to 1.
    0 means completely irrelevant, 1 means perfectly relevant.
    Provide ONLY the score as a float between 0 and 1.
    """
    user_prompt = f"Query: {query}\n\nDocument: {document}"
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,
            max_tokens=5
        )
        score_text = response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error evaluating document relevance: {e}")
        return 0.5
    return _parse_relevance_score(score_text)


def _parse_relevance_score(score_text, default=0.5):
    """Extract the first number in ``score_text`` and clamp it to [0, 1].

    Accepts forms like ``"0.85"``, ``".85"``, ``"1"`` or ``"Score: 0.9"``.
    Clamping keeps an off-scale reply (e.g. ``"8"``) from being read as
    "more than perfectly relevant" by downstream threshold checks.
    Returns ``default`` when no number is present.
    """
    # Try the decimal form first so "0.85" is not split into "0" and "85".
    match = re.search(r"\d*\.\d+|\d+", score_text)
    if match is None:
        return default
    return min(max(float(match.group(0)), 0.0), 1.0)

Web Search Function

def duck_duck_go_search(query, num_results=3, timeout=10):
    """Search the web via DuckDuckGo, falling back to SerpApi on failure.

    The primary source is the DuckDuckGo Instant Answer API. It returns an
    abstract plus "related topics" rather than a full list of web results.
    The SerpApi DuckDuckGo engine is tried only if the primary request or its
    JSON parsing raises.

    Args:
        query: Search string.
        num_results: Maximum number of related topics / organic results to keep.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        ``(results_text, sources)``. ``results_text`` holds the snippets, each
        followed by a blank line. ``sources`` is a list of ``{"title", "url"}``
        dicts. If both backends fail, returns
        ``("Failed to retrieve search results.", [])``.
    """
    headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
    }
    try:
        # ``params`` lets requests URL-encode the query. Without a timeout,
        # a stalled connection would block the whole pipeline indefinitely.
        response = requests.get(
            "https://api.duckduckgo.com/",
            params={"q": query, "format": "json"},
            headers=headers,
            timeout=timeout,
        )
        response.raise_for_status()
        return _parse_duckduckgo_instant_answer(response.json(), num_results)
    except Exception as e:
        print(f"Error performing web search: {e}")
    try:
        return _search_serpapi_duckduckgo(query, num_results, timeout)
    except Exception as backup_error:
        print(f"Backup search also failed: {backup_error}")
        return "Failed to retrieve search results.", []


def _parse_duckduckgo_instant_answer(data, num_results):
    """Convert a DuckDuckGo Instant Answer JSON payload into (text, sources)."""
    snippets = []
    sources = []
    if data.get("AbstractText"):
        snippets.append(f"{data['AbstractText']}\n\n")
        sources.append({
            "title": data.get("AbstractSource", "Wikipedia"),
            "url": data.get("AbstractURL", "")
        })
    # RelatedTopics may contain category groups that nest their own entries
    # under "Topics". Flatten one level, then keep the first ``num_results``
    # usable topics, so groups don't silently consume result slots.
    topics = []
    for entry in data.get("RelatedTopics", []):
        topics.extend(entry.get("Topics", [entry]))
    usable = [t for t in topics if "Text" in t and "FirstURL" in t][:num_results]
    for topic in usable:
        snippets.append(f"{topic['Text']}\n\n")
        sources.append({
            "title": topic.get("Text", "").split(" - ")[0],
            "url": topic.get("FirstURL", "")
        })
    return "".join(snippets), sources


def _search_serpapi_duckduckgo(query, num_results, timeout):
    """Query SerpApi's DuckDuckGo engine and convert organic results to (text, sources)."""
    params = {"q": query, "engine": "duckduckgo"}
    # SerpApi requires an API key; pass one when SERPAPI_API_KEY is configured.
    api_key = os.getenv("SERPAPI_API_KEY")
    if api_key:
        params["api_key"] = api_key
    response = requests.get(
        "https://serpapi.com/search.json", params=params, timeout=timeout
    )
    response.raise_for_status()
    organic = response.json().get("organic_results", [])[:num_results]
    results_text = "".join(
        f"{result.get('title', '')}: {result.get('snippet', '')}\n\n" for result in organic
    )
    sources = [
        {"title": result.get("title", ""), "url": result.get("link", "")}
        for result in organic
    ]
    return results_text, sources

def rewrite_search_query(query):
    """Use the LLM to turn a natural-language question into a concise web query.

    Args:
        query: The original user question.

    Returns:
        The rewritten query, with one pair of wrapping quotes removed if the
        model added them. Returns the original ``query`` unchanged if the API
        call fails or the model returns nothing usable.
    """
    system_prompt = """
    You are an expert at creating effective search queries.
    Rewrite the given query to make it more suitable for a web search engine.
    Focus on keywords and facts, remove unnecessary words, and make it concise.
    """
    user_prompt = f"Original query: {query}\n\nRewritten query:"
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0.3,
            max_tokens=50
        )
        rewritten = response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error rewriting search query: {e}")
        return query
    # Models often wrap the whole answer in quotes, which a search engine treats
    # as an exact-phrase search. Only strip a matching pair around the entire
    # string so inner phrase quotes are left intact.
    if len(rewritten) >= 2 and rewritten[0] == rewritten[-1] and rewritten[0] in "\"'":
        rewritten = rewritten[1:-1].strip()
    # An empty rewrite would produce an empty search; fall back to the original.
    return rewritten or query

def perform_web_search(query):
    """Rewrite ``query`` for a search engine, then search DuckDuckGo with it.

    Returns the ``(results_text, sources)`` pair produced by
    ``duck_duck_go_search``.
    """
    search_query = rewrite_search_query(query)
    print(f"Rewritten search query: {search_query}")
    return duck_duck_go_search(search_query)

Knowledge Refinement Function

def refine_knowledge(text):
    """Condense ``text`` into LLM-generated "• " bullet points of key facts.

    Args:
        text: Raw knowledge (a document chunk or concatenated web snippets).

    Returns:
        The bulleted summary. ``text`` is returned unchanged when it is empty
        or whitespace-only (e.g. a web search that found nothing), avoiding a
        pointless API call. It is also returned unchanged if the API call fails.
    """
    if not text or not text.strip():
        return text
    system_prompt = """
    Extract the key information from the following text as a set of clear, concise bullet points.
    Focus on the most relevant facts and important details.
    Format your response as a bulleted list with each point on a new line starting with "• ".
    """
    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Text to refine:\n\n{text}"}
            ],
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error refining knowledge: {e}")
        return text

Core CRAG Process

def crag_process(query, vector_store, k=3, upper_threshold=0.7, lower_threshold=0.3):
    """Answer ``query`` with Corrective RAG.

    Steps:
      1. Retrieve the top-``k`` chunks from ``vector_store``.
      2. Score each chunk's relevance with ``evaluate_document_relevance``.
      3. Choose the knowledge source from the best score:
         * ``> upper_threshold``: use the best chunk verbatim;
         * ``< lower_threshold`` (or nothing retrieved): use refined web results only;
         * otherwise: combine the refined best chunk with refined web results.
      4. Generate the final answer with ``generate_response``.

    Args:
        query: The user question.
        vector_store: Store exposing ``similarity_search(embedding, k=...)``.
        k: Number of chunks to retrieve and score.
        upper_threshold: Scores strictly above this trust the document alone.
        lower_threshold: Scores strictly below this discard local documents.

    Returns:
        Dict with ``query``, ``response``, ``retrieved_docs`` (each annotated
        in place with a ``"relevance"`` score), ``relevance_scores``,
        ``max_relevance``, ``final_knowledge`` and ``sources``.

    Raises:
        ValueError: If ``lower_threshold`` is greater than ``upper_threshold``.
    """
    if lower_threshold > upper_threshold:
        raise ValueError(
            f"lower_threshold ({lower_threshold}) must not exceed upper_threshold ({upper_threshold})"
        )
    print(f"\n=== Processing query with CRAG: {query} ===\n")
    print("Retrieving initial documents...")
    query_embedding = create_embeddings(query)
    retrieved_docs = vector_store.similarity_search(query_embedding, k=k)

    print("Evaluating document relevance...")
    relevance_scores = []
    for doc in retrieved_docs:
        score = evaluate_document_relevance(query, doc["text"])
        relevance_scores.append(score)
        doc["relevance"] = score
        print(f"Document scored {score:.2f} relevance")

    max_score = max(relevance_scores, default=0)
    sources = []
    if retrieved_docs and max_score > upper_threshold:
        print(f"High relevance ({max_score:.2f}) - Using document directly")
        best_doc = retrieved_docs[relevance_scores.index(max_score)]["text"]
        final_knowledge = best_doc
        sources.append({"title": "Document", "url": ""})
    elif not retrieved_docs or max_score < lower_threshold:
        # Also taken when retrieval returned nothing, so the medium branch
        # never indexes an empty list regardless of the thresholds chosen.
        print(f"Low relevance ({max_score:.2f}) - Performing web search")
        web_results, web_sources = perform_web_search(query)
        final_knowledge = refine_knowledge(web_results)
        sources.extend(web_sources)
    else:
        print(f"Medium relevance ({max_score:.2f}) - Combining document with web search")
        best_doc = retrieved_docs[relevance_scores.index(max_score)]["text"]
        refined_doc = refine_knowledge(best_doc)
        web_results, web_sources = perform_web_search(query)
        refined_web = refine_knowledge(web_results)
        final_knowledge = f"From document:\n{refined_doc}\n\nFrom web search:\n{refined_web}"
        sources.append({"title": "Document", "url": ""})
        sources.extend(web_sources)

    print("Generating final response...")
    response = generate_response(query, final_knowledge, sources)
    return {
        "query": query,
        "response": response,
        "retrieved_docs": retrieved_docs,
        "relevance_scores": relevance_scores,
        "max_relevance": max_score,
        "final_knowledge": final_knowledge,
        "sources": sources
    }

Response Generation

def generate_response(query, knowledge, sources):
    """Generate the final answer to ``query`` from the selected knowledge.

    Args:
        query: The user question.
        knowledge: Text the model should base its answer on.
        sources: Dicts with optional ``title`` and ``url`` keys, listed to the
            model for attribution.

    Returns:
        The model's answer. If the API call fails, an apology string that
        embeds the query and the error; it is also printed.
    """
    def _format_source(source):
        # One "- title: url" line, or "- title" when the url is empty.
        title = source.get("title", "Unknown Source")
        url = source.get("url", "")
        return f"- {title}: {url}\n" if url else f"- {title}\n"

    sources_text = "".join(_format_source(entry) for entry in sources)
    system_prompt = """
    You are a helpful AI assistant. Generate a comprehensive, informative response to the query based on the provided knowledge.
    Include all relevant information while keeping your answer clear and concise.
    If the knowledge doesn't fully answer the query, acknowledge this limitation.
    Include source attribution at the end of your response.
    """
    user_prompt = f"""
    Query: {query}
    Knowledge:
    {knowledge}
    Sources:
    {sources_text}
    Please provide an informative response to the query based on this information.
    Include the sources at the end of your response.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model="gpt-4", messages=messages, temperature=0.2
        )
        return completion.choices[0].message.content.strip()
    except Exception as err:
        apology = f"I apologize, but I encountered an error while generating a response to your query: '{query}'. The error was: {str(err)}"
        print(apology)
        return apology

Evaluation Functions

def evaluate_crag_response(query, response, reference_answer=None):
    """Grade ``response`` with an LLM judge on five 0-10 criteria.

    Args:
        query: The question that was answered.
        response: The answer text to grade.
        reference_answer: Optional gold answer shown to the judge for comparison.

    Returns:
        The dict parsed from the judge's JSON reply. The prompt asks for
        per-criterion scores plus ``overall_score`` and ``summary``; the exact
        keys depend on the model. On any failure, returns
        ``{"error": ..., "overall_score": 0, "summary": ...}``.
    """
    system_prompt = """
    You are an expert at evaluating the quality of responses to questions.
    Please evaluate the provided response based on the following criteria:
        1. Relevance (0-10): How directly does the response address the query?
    2. Accuracy (0-10): How factually correct is the information?
    3. Completeness (0-10): How thoroughly does the response answer all aspects of the query?
    4. Clarity (0-10): How clear and easy to understand is the response?
    5. Source Quality (0-10): How well does the response cite relevant sources?
    Return your evaluation as a JSON object with scores for each criterion and a brief explanation for each score.
    Also include an "overall_score" (0-10) and a brief "summary" of your evaluation.
    """
    user_prompt = f"""
    Query: {query}
    Response to evaluate:
    {response}
    """
    if reference_answer:
        user_prompt += f"""
    Reference answer (for comparison):
    {reference_answer}
    """
    try:
        # No response_format={"type": "json_object"}: JSON mode is not
        # supported by the base "gpt-4" model, so requesting it made every
        # evaluation fail. The prompt already asks for JSON, so extract the
        # object from the reply text instead.
        evaluation_response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0
        )
        return _extract_json_object(evaluation_response.choices[0].message.content)
    except Exception as e:
        print(f"Error evaluating response: {e}")
        return {
            "error": str(e),
            "overall_score": 0,
            "summary": "Evaluation failed due to an error."
        }


def _extract_json_object(text):
    """Decode the first JSON object embedded in ``text``.

    Chat models without JSON mode may surround the object with prose or a
    Markdown code fence. Decoding starts at the first "{" and ignores any
    trailing text after the object ends.

    Raises:
        ValueError: If ``text`` contains no decodable JSON object
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    start = text.find("{")
    if start == -1:
        raise ValueError("no JSON object found in model output")
    obj, _end = json.JSONDecoder().raw_decode(text, start)
    return obj

def compare_crag_vs_standard_rag(query, vector_store, reference_answer=None, k=3):
    """Run CRAG and plain top-k RAG on one query, then grade and compare them.

    Args:
        query: The user question.
        vector_store: Store exposing ``similarity_search``.
        reference_answer: Optional gold answer passed to the evaluators.
        k: Number of chunks both approaches retrieve.

    Returns:
        Dict with both responses, the reference answer, each response's
        ``evaluate_crag_response`` result, and the ``compare_responses`` text.
    """
    print("\n=== Running CRAG ===")
    crag_result = crag_process(query, vector_store, k=k)
    crag_response = crag_result["response"]

    print("\n=== Running standard RAG ===")
    # Standard RAG uses the same top-k similarity search CRAG already ran.
    # Reusing it saves a second embedding call and guarantees both approaches
    # start from identical candidate chunks.
    retrieved_docs = crag_result["retrieved_docs"]
    combined_text = "\n\n".join(doc["text"] for doc in retrieved_docs)
    standard_sources = [{"title": "Document", "url": ""}]
    standard_response = generate_response(query, combined_text, standard_sources)

    print("\n=== Evaluating CRAG response ===")
    crag_eval = evaluate_crag_response(query, crag_response, reference_answer)
    print("\n=== Evaluating standard RAG response ===")
    standard_eval = evaluate_crag_response(query, standard_response, reference_answer)
    print("\n=== Comparing approaches ===")
    comparison = compare_responses(query, crag_response, standard_response, reference_answer)
    return {
        "query": query,
        "crag_response": crag_response,
        "standard_response": standard_response,
        "reference_answer": reference_answer,
        "crag_evaluation": crag_eval,
        "standard_evaluation": standard_eval,
        "comparison": comparison
    }

def compare_responses(query, crag_response, standard_response, reference_answer=None):
    """Ask the LLM for a head-to-head comparison of a CRAG and a standard-RAG answer.

    Args:
        query: The question both systems answered.
        crag_response: Answer produced by CRAG.
        standard_response: Answer produced by standard RAG.
        reference_answer: Optional gold answer included in the prompt when truthy.

    Returns:
        The model's comparison text, or ``"Error comparing responses: <error>"``
        if the API call fails.
    """
    system_prompt = """
    You are an expert evaluator comparing two response generation approaches:
        1. CRAG (Corrective RAG): A system that evaluates document relevance and dynamically switches to web search when needed.
    2. Standard RAG: A system that directly retrieves documents based on embedding similarity and uses them for response generation.
    Compare the responses from these two systems based on:
    - Accuracy and factual correctness
    - Relevance to the query
    - Completeness of the answer
    - Clarity and organization
    - Source attribution quality
    Explain which approach performed better for this specific query and why.
    """
    # Assemble the user prompt section by section; the reference answer is optional.
    sections = [
        f"\n    Query: {query}\n    CRAG Response:\n    {crag_response}"
        f"\n    Standard RAG Response:\n    {standard_response}\n    "
    ]
    if reference_answer:
        sections.append(f"\n    Reference Answer:\n    {reference_answer}\n    ")
    sections.append(
        "\n    Please provide a detailed comparison of these two responses, "
        "highlighting which approach performed better and why.\n    "
    )
    user_prompt = "".join(sections)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model="gpt-4", messages=messages, temperature=0
        )
        return completion.choices[0].message.content.strip()
    except Exception as err:
        print(f"Error comparing responses: {err}")
        return f"Error comparing responses: {str(err)}"

Complete Evaluation Pipeline

def run_crag_evaluation(pdf_path, test_queries, reference_answers=None):
    """Index a PDF, then compare CRAG against standard RAG on each test query.

    Args:
        pdf_path: PDF used to build the vector store.
        test_queries: Questions to evaluate.
        reference_answers: Optional gold answers matched to ``test_queries`` by
            position; queries beyond its length get no reference.

    Returns:
        ``{"results": [...], "overall_analysis": str}``. Each result is the
        dict returned by ``compare_crag_vs_standard_rag``.
    """
    vector_store = process_document(pdf_path)
    references = reference_answers or []
    results = []
    for number, query in enumerate(test_queries, start=1):
        print(f"\n\n===== Evaluating Query {number}/{len(test_queries)} =====")
        print(f"Query: {query}")
        reference = references[number - 1] if number <= len(references) else None
        result = compare_crag_vs_standard_rag(query, vector_store, reference)
        results.append(result)
        print("\n=== Comparison ===")
        print(result["comparison"])
    return {
        "results": results,
        "overall_analysis": generate_overall_analysis(results),
    }

def generate_overall_analysis(results):
    """Summarize per-query CRAG vs standard-RAG results and ask the LLM for an overall analysis.

    Args:
        results: Dicts as returned by ``compare_crag_vs_standard_rag``. Each
            must have ``query``; ``crag_evaluation``, ``standard_evaluation``
            and ``comparison`` are used when present.

    Returns:
        The model's analysis text, or ``"Error generating overall analysis: <error>"``
        if the API call fails.
    """
    system_prompt = """
    You are an expert at evaluating information retrieval and response generation systems.
    Based on multiple test queries, provide an overall analysis comparing CRAG (Corrective RAG) 
    with standard RAG.
    Focus on:
    1. When CRAG performs better and why
    2. When standard RAG performs better and why
    3. The overall strengths and weaknesses of each approach
    4. Recommendations for when to use each approach
    """
    # Build the summary as a list of lines. The original initialized this with
    # an unterminated triple-quoted string, which swallowed the loop below and
    # made the file a SyntaxError.
    summary_lines = []
    for i, result in enumerate(results, start=1):
        summary_lines.append(f"Query {i}: {result['query']}\n")
        crag_eval = result.get("crag_evaluation") or {}
        if "overall_score" in crag_eval:
            summary_lines.append(f"CRAG score: {crag_eval['overall_score']}\n")
        std_eval = result.get("standard_evaluation") or {}
        if "overall_score" in std_eval:
            summary_lines.append(f"Standard RAG score: {std_eval['overall_score']}\n")
        # Only the first 200 characters of each comparison keep the prompt short.
        comparison = str(result.get("comparison", ""))
        summary_lines.append(f"Comparison summary: {comparison[:200]}...\n\n")
    evaluations_summary = "".join(summary_lines)
    user_prompt = f"""
    Based on the following evaluations comparing CRAG vs standard RAG across {len(results)} queries, 
    provide an overall analysis of these two approaches:
    {evaluations_summary}
    Please provide a comprehensive analysis of the relative strengths and weaknesses of CRAG 
    compared to standard RAG, focusing on when and why one approach outperforms the other.
    """
    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        print(f"Error generating overall analysis: {e}")
        return f"Error generating overall analysis: {str(e)}"

Evaluation of CRAG with Test Queries

if __name__ == "__main__":
    # Example run: index the sample PDF and compare CRAG with standard RAG.
    # The guard keeps importing this module from triggering API calls; in a
    # notebook ``__name__`` is "__main__", so the cell still runs as before.
    pdf_path = "data/AI_Information.pdf"
    test_queries = [
        "How does machine learning differ from traditional programming?",
    ]
    # Reference answers are matched to test_queries by position.
    reference_answers = [
        "Machine learning differs from traditional programming by having computers learn patterns "
        "from data rather than following explicit instructions. In traditional programming, "
        "developers write specific rules for the computer to follow, while in machine learning "
        "the algorithm infers those rules from example data, so it can handle inputs it was "
        "never explicitly programmed for.",
    ]
    evaluation_results = run_crag_evaluation(pdf_path, test_queries, reference_answers)
    print("\n=== Overall Analysis of CRAG vs Standard RAG ===")
    print(evaluation_results["overall_analysis"])