Oreoluwa
Graph RAG: Graph-Enhanced Retrieval-Augmented Generation
June 25, 2025
30 min read

Graph RAG: Graph-Enhanced Retrieval-Augmented Generation

RAG
NLP
AI
Knowledge Graph
Graph RAG

Graph RAG: Graph-Enhanced Retrieval-Augmented Generation

In this notebook, I implement Graph RAG — a technique that enhances traditional RAG systems by organizing knowledge as a connected graph rather than a flat collection of documents. This lets the system navigate between related concepts and retrieve more contextually relevant information than standard vector-similarity search alone.

Key Benefits of Graph RAG

  • Preserves relationships between pieces of information
  • Enables traversal through connected concepts to find relevant context
  • Improves handling of complex, multi-part queries
  • Provides better explainability through visualized knowledge paths

Setting Up the Environment

We begin by importing the necessary libraries.

import os
import numpy as np
import json
import fitz # PyMuPDF
from openai import OpenAI
from typing import List, Dict, Tuple, Any
import networkx as nx
import matplotlib.pyplot as plt
import heapq
from collections import defaultdict
import re
from PIL import Image
import io
from sklearn.metrics.pairwise import cosine_similarity

Setting Up the OpenAI API Client

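We initialize an OpenAI client to generate embeddings and responses. The client points at Nebius AI Studio's OpenAI-compatible endpoint and reads the API key from the `OPENAI_API_KEY` environment variable.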

# Shared API client used for both embeddings and chat completions. The OpenAI
# SDK is pointed at Nebius AI Studio's OpenAI-compatible endpoint; the key is
# taken from the OPENAI_API_KEY environment variable (None if unset).
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url="https://api.studio.nebius.com/v1/",
)

Document Processing Functions

def extract_text_from_pdf(pdf_path):
    """Extract and concatenate the text of every page of a PDF.

    Args:
        pdf_path: Path to the PDF file to read.

    Returns:
        The text of all pages in page order, joined with no separator.
    """
    print(f"Extracting text from {pdf_path}...")
    # Open the document as a context manager so the underlying file handle is
    # released even if text extraction raises part-way through.
    with fitz.open(pdf_path) as pdf_document:
        return "".join(page.get_text() for page in pdf_document)

def chunk_text(text, chunk_size=1000, overlap=200):
    """Split text into fixed-size, overlapping character windows.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must be
            non-negative and smaller than ``chunk_size``.

    Returns:
        A list of dicts with keys ``text``, ``index`` (position in the list),
        ``start_pos`` and ``end_pos`` (end-exclusive character offsets into
        ``text``). An empty ``text`` yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not in
            ``[0, chunk_size)``. Such values would make the window step zero
            or negative, or would silently skip text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got overlap={overlap}, chunk_size={chunk_size}"
        )

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        # start < len(text) and chunk_size > 0, so the slice is never empty.
        piece = text[start:start + chunk_size]
        chunks.append({
            "text": piece,
            "index": len(chunks),
            "start_pos": start,
            "end_pos": start + len(piece),
        })
    print(f"Created {len(chunks)} text chunks")
    return chunks

Creating Embeddings

def create_embeddings(texts, model="BAAI/bge-en-icl", batch_size=100):
    """Embed text(s) with the configured embedding model.

    Args:
        texts: A list of strings, or a single string.
        model: Embedding model name passed to the API.
        batch_size: Maximum number of texts sent per API request; must be
            positive.

    Returns:
        For a list, a list of embedding vectors in the order the API returns
        them (one per input). For a single non-empty string, that string's
        embedding vector. Empty input (``[]`` or ``""``) returns ``[]``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if not texts:
        return []
    if isinstance(texts, str):
        # A bare string would otherwise be sliced below into batches of
        # characters, embedding fragments instead of the whole text.
        return create_embeddings([texts], model=model, batch_size=batch_size)[0]
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        response = client.embeddings.create(
            model=model,
            input=list(texts[start:start + batch_size]),
        )
        all_embeddings.extend(item.embedding for item in response.data)
    return all_embeddings

Knowledge Graph Construction

def extract_concepts(text):
    """Ask the chat model for the key concepts/entities in ``text``.

    Only the first 3000 characters of ``text`` are sent to the model.

    Returns:
        A list of non-empty, whitespace-stripped concept strings (possibly
        empty). Malformed model output yields an empty or partial list rather
        than an exception.
    """
    # JSON mode (response_format=json_object) makes the model emit a JSON
    # *object*, so ask for an object with a "concepts" key instead of a bare
    # array. The parser still accepts other shapes as a fallback.
    system_message = """Extract key concepts and entities from the provided text.
    Return ONLY the 5-10 key terms, entities, or concepts that are most important in this text.
    Format your response as a JSON object of the form {"concepts": ["term1", "term2", ...]}."""
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Extract key concepts from:\n\n{text[:3000]}"}
        ],
        temperature=0.0,
        response_format={"type": "json_object"}
    )
    return _parse_concepts(response.choices[0].message.content)


def _parse_concepts(content):
    """Turn a model reply into a list of concept strings (never raises on bad JSON)."""
    if not content:
        return []
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        # Not valid JSON: pull the quoted strings out of the first [...] span.
        match = re.search(r'\[(.*?)\]', content, re.DOTALL)
        parsed = re.findall(r'"([^"]*)"', match.group(1)) if match else []

    if isinstance(parsed, dict):
        concepts = parsed.get("concepts")
        if not isinstance(concepts, list):
            # Model used a different key: take the first list-valued field.
            concepts = next((value for value in parsed.values() if isinstance(value, list)), [])
    elif isinstance(parsed, list):
        concepts = parsed
    else:
        concepts = []

    # Keep only non-empty strings: callers put concepts into sets, and
    # unhashable items (e.g. nested objects) would break that.
    return [item.strip() for item in concepts if isinstance(item, str) and item.strip()]

def build_knowledge_graph(chunks, edge_threshold=0.6, similarity_weight=0.7, concept_weight=0.3):
    """Build a graph of chunks linked by shared concepts and embedding similarity.

    Each chunk becomes node ``i`` (its position in ``chunks``) with attributes
    ``text``, ``concepts`` (as returned by ``extract_concepts``) and
    ``embedding``. Two chunks are connected when they share at least one
    concept (compared case-insensitively) and the blended score

        similarity_weight * cosine_similarity + concept_weight * concept_overlap

    is strictly greater than ``edge_threshold``. ``concept_overlap`` is the
    number of shared concepts divided by the size of the smaller concept set.

    Args:
        chunks: Dicts with at least a ``text`` key (as produced by ``chunk_text``).
        edge_threshold: Minimum blended score (exclusive) for an edge.
        similarity_weight: Weight of the embedding cosine similarity.
        concept_weight: Weight of the concept-overlap score.

    Returns:
        ``(graph, embeddings)``: the networkx Graph and the chunk embeddings
        in chunk order. Edges carry ``weight``, ``similarity`` and
        ``shared_concepts`` (normalized, sorted) attributes.
    """
    print("Building knowledge graph...")
    graph = nx.Graph()
    texts = [chunk["text"] for chunk in chunks]
    print("Creating embeddings for chunks...")
    embeddings = create_embeddings(texts)

    print("Adding nodes to the graph...")
    for i, chunk in enumerate(chunks):
        print(f"Extracting concepts for chunk {i+1}/{len(chunks)}...")
        concepts = extract_concepts(chunk["text"])
        graph.add_node(i,
                       text=chunk["text"],
                       concepts=concepts,
                       embedding=embeddings[i])

    print("Creating edges between nodes...")
    # Normalize once so "Transformers" and "transformers " count as the same concept.
    concept_sets = [_normalize_concepts(graph.nodes[i]["concepts"]) for i in range(len(chunks))]
    # Each norm is computed once rather than once per pair.
    norms = [float(np.linalg.norm(vector)) for vector in embeddings]

    for i in range(len(chunks)):
        for j in range(i + 1, len(chunks)):
            shared_concepts = concept_sets[i] & concept_sets[j]
            if not shared_concepts:
                continue
            norm_product = norms[i] * norms[j]
            # A zero vector has no direction; treat it as dissimilar instead of NaN.
            similarity = (float(np.dot(embeddings[i], embeddings[j]) / norm_product)
                          if norm_product else 0.0)
            # Both sets are non-empty here (they share a concept), so min() > 0.
            concept_score = len(shared_concepts) / min(len(concept_sets[i]), len(concept_sets[j]))
            edge_weight = similarity_weight * similarity + concept_weight * concept_score
            if edge_weight > edge_threshold:
                graph.add_edge(i, j,
                               weight=edge_weight,
                               similarity=similarity,
                               shared_concepts=sorted(shared_concepts))
    print(f"Knowledge graph built with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
    return graph, embeddings


def _normalize_concepts(concepts):
    """Return concepts as a set of lower-cased, stripped, non-empty strings."""
    return {item.strip().lower() for item in (concepts or ()) if isinstance(item, str) and item.strip()}

Graph Traversal and Query Processing

def traverse_graph(query, graph, embeddings, top_k=5, max_depth=3):
    """Retrieve chunks by best-first traversal of the knowledge graph.

    The ``top_k`` nodes whose embeddings are most cosine-similar to the query
    seed a priority queue. Nodes are popped highest priority first. Seeds are
    ranked by query similarity; neighbours are ranked by the weight of the
    edge that reached them. A node's neighbours are expanded only while it is
    fewer than ``max_depth`` hops from a seed. Traversal stops once
    ``3 * top_k`` chunks have been collected or the queue is empty.

    Args:
        query: The user question.
        graph: Graph built by ``build_knowledge_graph``; nodes must carry
            ``text`` and ``concepts``, edges ``weight``.
        embeddings: Chunk embeddings, indexed by node id.
        top_k: Number of seed nodes.
        max_depth: Maximum number of hops from a seed node to explore.

    Returns:
        ``(results, traversal_path)``. ``results`` holds dicts with ``text``,
        ``concepts`` and ``node_id`` in visit order. ``traversal_path`` lists
        the visited node ids in the same order.
    """
    print(f"Traversing graph for query: {query}")
    # Wrap the query in a list: a bare string may be sliced into character
    # batches by create_embeddings instead of being embedded whole.
    query_embedding = np.asarray(create_embeddings([query])[0], dtype=float)
    query_norm = np.linalg.norm(query_embedding)

    # Similarity keyed by node id. Seed priorities are looked up by node id
    # below, so this must not be a list that later gets re-sorted in place.
    node_similarity = {}
    for node_id, node_embedding in enumerate(embeddings):
        denom = query_norm * np.linalg.norm(node_embedding)
        node_similarity[node_id] = (float(np.dot(query_embedding, node_embedding) / denom)
                                    if denom else 0.0)

    # sorted() is stable, so ties keep ascending node-id order.
    starting_nodes = sorted(node_similarity, key=node_similarity.get, reverse=True)[:top_k]
    print(f"Starting traversal from {len(starting_nodes)} nodes")

    visited = set()
    traversal_path = []
    results = []
    # Heap entries: (negated priority, node id, hops from a seed node).
    queue = [(-node_similarity[node], node, 0) for node in starting_nodes]
    heapq.heapify(queue)
    max_results = top_k * 3

    while queue and len(results) < max_results:
        _, node, depth = heapq.heappop(queue)
        if node in visited:
            continue
        visited.add(node)
        traversal_path.append(node)
        results.append({
            "text": graph.nodes[node]["text"],
            "concepts": graph.nodes[node]["concepts"],
            "node_id": node
        })
        if depth < max_depth:
            for neighbor in graph.neighbors(node):
                if neighbor not in visited:
                    heapq.heappush(queue, (-graph[node][neighbor]["weight"], neighbor, depth + 1))

    print(f"Graph traversal found {len(results)} relevant chunks")
    return results, traversal_path

Response Generation

def generate_response(query, context_chunks):
    """Answer ``query`` with the chat model, grounded in retrieved chunks.

    The chunks' ``text`` fields are joined with ``---`` separators. The joined
    context is cut to its first 14000 characters, followed by a
    ``... [truncated]`` marker, when it is longer than that. The result is
    sent together with the question.

    Returns:
        The model's answer text.
    """
    max_context = 14000
    combined_context = "\n\n---\n\n".join(chunk["text"] for chunk in context_chunks)
    if len(combined_context) > max_context:
        combined_context = f"{combined_context[:max_context]}... [truncated]"

    system_prompt = (
        "You are a helpful AI assistant. Answer the user's question based on the provided context.\n"
        "If the information is not in the context, say so. Refer to specific parts of the context in your answer when possible."
    )
    user_prompt = f"Context:\n{combined_context}\n\nQuestion: {query}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=messages,
        temperature=0.2,
    )
    return completion.choices[0].message.content

Visualization

def visualize_graph_traversal(graph, traversal_path):
    """Draw the knowledge graph and highlight the nodes visited by a traversal.

    Node colors:
    - Unvisited nodes are light blue and visited nodes light green.
    - The first visited node is green; the last is red (if the path has one
      node, red wins).

    Edge widths are twice the edge ``weight`` (default 1.0). Consecutive
    visited nodes that are joined by an actual graph edge are overlaid with a
    red dashed line. Each node is labelled ``"<id>: <first concept>"``, or
    ``"<id>: Node <id>"`` when it has no concepts. The figure is shown with
    ``plt.show()``.
    """
    plt.figure(figsize=(12, 10))

    # Key colors by node id, then emit them in graph.nodes() order, which is
    # the order draw_networkx_nodes pairs a color list with. This avoids
    # assuming node ids are exactly 0..n-1.
    color_by_node = {node: 'lightblue' for node in graph.nodes()}
    for node in traversal_path:
        color_by_node[node] = 'lightgreen'
    if traversal_path:
        color_by_node[traversal_path[0]] = 'green'
        color_by_node[traversal_path[-1]] = 'red'
    node_color = [color_by_node[node] for node in graph.nodes()]

    pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=42)
    nx.draw_networkx_nodes(graph, pos, node_color=node_color, node_size=500, alpha=0.8)

    # Draw all edges in a single call, with per-edge widths.
    edges = list(graph.edges(data=True))
    if edges:
        nx.draw_networkx_edges(graph, pos,
                               edgelist=[(u, v) for u, v, _ in edges],
                               width=[data.get('weight', 1.0) * 2 for _, _, data in edges],
                               alpha=0.6)

    # Consecutive visits can jump between unconnected seed nodes. Only overlay
    # hops that are real edges so the plot does not invent connections.
    traversal_edges = [(a, b) for a, b in zip(traversal_path, traversal_path[1:])
                       if graph.has_edge(a, b)]
    if traversal_edges:
        nx.draw_networkx_edges(graph, pos, edgelist=traversal_edges,
                               width=3, alpha=0.8, edge_color='red',
                               style='dashed', arrows=True)

    labels = {}
    for node in graph.nodes():
        concepts = graph.nodes[node].get('concepts') or []
        label = concepts[0] if concepts else f"Node {node}"
        labels[node] = f"{node}: {label}"
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=8)

    plt.title("Knowledge Graph with Traversal Path")
    plt.axis('off')
    plt.tight_layout()
    plt.show()

Complete Graph RAG Pipeline

def graph_rag_pipeline(pdf_path, query, chunk_size=1000, chunk_overlap=200, top_k=3):
    """Run the whole Graph RAG flow for one PDF and one query.

    The steps are: extract the PDF text, chunk it, build the knowledge graph,
    traverse it from the ``top_k`` most query-similar chunks, generate an
    answer from the retrieved chunks, and plot the traversal.

    Returns:
        A dict with ``query``, ``response``, ``relevant_chunks``,
        ``traversal_path`` and ``graph``.
    """
    document_text = extract_text_from_pdf(pdf_path)
    document_chunks = chunk_text(document_text, chunk_size, chunk_overlap)
    graph, chunk_embeddings = build_knowledge_graph(document_chunks)

    retrieved, path = traverse_graph(query, graph, chunk_embeddings, top_k)
    answer = generate_response(query, retrieved)
    visualize_graph_traversal(graph, path)

    return dict(
        query=query,
        response=answer,
        relevant_chunks=retrieved,
        traversal_path=path,
        graph=graph,
    )

Evaluation Function

def evaluate_graph_rag(pdf_path, test_queries, reference_answers=None):
    """Build one knowledge graph for a PDF and answer a list of test queries.

    For each query, the graph is traversed and a response generated. When a
    reference answer exists at the same index, the model is also asked to
    compare the response with it. Progress and responses are printed.

    Args:
        pdf_path: Path to the PDF to index.
        test_queries: Questions to evaluate.
        reference_answers: Optional list aligned by index with
            ``test_queries``; it may be shorter, in which case later queries
            get no comparison.

    Returns:
        A dict with ``results`` (one dict per query) and ``graph_stats``
        (``nodes``, ``edges``, ``avg_degree``; ``avg_degree`` is 0.0 for an
        empty graph).
    """
    text = extract_text_from_pdf(pdf_path)
    chunks = chunk_text(text)
    graph, embeddings = build_knowledge_graph(chunks)

    results = []
    for i, query in enumerate(test_queries):
        print(f"\n\n=== Evaluating Query {i+1}/{len(test_queries)} ===")
        print(f"Query: {query}")
        relevant_chunks, traversal_path = traverse_graph(query, graph, embeddings)
        response = generate_response(query, relevant_chunks)

        reference = None
        comparison = None
        if reference_answers and i < len(reference_answers):
            reference = reference_answers[i]
            comparison = compare_with_reference(response, reference, query)

        results.append({
            "query": query,
            "response": response,
            "reference_answer": reference,
            "comparison": comparison,
            "traversal_path_length": len(traversal_path),
            "relevant_chunks_count": len(relevant_chunks)
        })
        print(f"\nResponse: {response}\n")
        if comparison:
            print(f"Comparison: {comparison}\n")

    node_count = graph.number_of_nodes()
    # Guard the division: a PDF with no extractable text produces no chunks
    # and therefore an empty graph.
    avg_degree = sum(dict(graph.degree()).values()) / node_count if node_count else 0.0
    return {
        "results": results,
        "graph_stats": {
            "nodes": node_count,
            "edges": graph.number_of_edges(),
            "avg_degree": avg_degree
        }
    }

def compare_with_reference(response, reference, query):
    """Ask the chat model to judge a generated answer against a reference.

    Args:
        response: The generated answer to grade.
        reference: The expected (reference) answer.
        query: The question both answers address.

    Returns:
        The model's short free-text assessment of correctness, completeness
        and relevance.
    """
    grading_instructions = (
        "Compare the AI-generated response with the reference answer.\n"
        "Evaluate based on: correctness, completeness, and relevance to the query.\n"
        "Provide a brief analysis (2-3 sentences) of how well the generated response matches the reference."
    )

    # The prompt starts with a newline. Every line carries a 4-space indent
    # and ends with a newline, and the prompt ends with a trailing indented
    # empty line.
    pad = "    "
    prompt_lines = [
        f"Query: {query}",
        "AI-generated response:",
        f"{response}",
        "Reference answer:",
        f"{reference}",
        "How well does the AI response match the reference?",
    ]
    prompt = "\n" + "".join(f"{pad}{line}\n" for line in prompt_lines) + pad

    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": grading_instructions},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
    )
    return completion.choices[0].message.content

Evaluation of Graph RAG on a Sample PDF Document

# Evaluation cases as (query, reference answer) pairs, split into the two
# parallel lists that evaluate_graph_rag expects.
_evaluation_cases = [
    (
        "What are the key applications of transformers in natural language processing?",
        "Transformer models have revolutionized natural language processing with applications including machine translation, text summarization, question answering, sentiment analysis, and text generation. They excel at capturing long-range dependencies in text and have become the foundation for models like BERT, GPT, and T5.",
    ),
]
test_queries = [query for query, _ in _evaluation_cases]
reference_answers = [answer for _, answer in _evaluation_cases]
del _evaluation_cases

pdf_path = "data/AI_Information.pdf"
evaluation_results = evaluate_graph_rag(
    pdf_path=pdf_path,
    test_queries=test_queries,
    reference_answers=reference_answers,
)