June 25, 2025
30 min read
Graph RAG: Graph-Enhanced Retrieval-Augmented Generation
RAG
NLP
AI
Knowledge Graph
Graph RAG
Graph RAG: Graph-Enhanced Retrieval-Augmented Generation
In this notebook, I implement Graph RAG — a technique that enhances traditional RAG systems by organizing knowledge as a connected graph rather than a flat collection of documents. This lets the system navigate related concepts and retrieve more contextually relevant information than standard vector-similarity approaches.
Key Benefits of Graph RAG
- Preserves relationships between pieces of information
- Enables traversal through connected concepts to find relevant context
- Improves handling of complex, multi-part queries
- Provides better explainability through visualized knowledge paths
Setting Up the Environment
We begin by importing the necessary libraries.
import os
import numpy as np
import json
import fitz # PyMuPDF
from openai import OpenAI
from typing import List, Dict, Tuple, Any
import networkx as nx
import matplotlib.pyplot as plt
import heapq
from collections import defaultdict
import re
from PIL import Image
import io
from sklearn.metrics.pairwise import cosine_similarity
Setting Up the OpenAI API Client
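We initialize an OpenAI-compatible client pointed at the Nebius AI Studio endpoint; it is used to generate both embeddings and chat responses. The API key is read from the `OPENAI_API_KEY` environment variable.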
# OpenAI-compatible client aimed at the Nebius AI Studio endpoint. The key is
# read from the OPENAI_API_KEY environment variable (None when it is unset).
client = OpenAI(
    api_key=os.environ.get("OPENAI_API_KEY"),
    base_url="https://api.studio.nebius.com/v1/",
)
Document Processing Functions
def extract_text_from_pdf(pdf_path):
    """Return the plain text of every page of a PDF, concatenated in page order.

    Args:
        pdf_path: Path to the PDF file; passed straight to ``fitz.open``.

    Returns:
        The text of all pages (via PyMuPDF's ``Page.get_text``) joined with
        no separator between pages.
    """
    print(f"Extracting text from {pdf_path}...")
    # Open the document as a context manager so the underlying file is
    # released even if extraction fails part-way through.
    with fitz.open(pdf_path) as pdf_document:
        # Iterating a Document yields its pages in order; join is linear,
        # unlike repeated `text += ...`. join is eager, so all pages are
        # read before the document is closed.
        return "".join(page.get_text() for page in pdf_document)
def chunk_text(text, chunk_size=1000, overlap=200):
    """Split ``text`` into fixed-size, overlapping character windows.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must
            satisfy ``0 <= overlap < chunk_size``.

    Returns:
        A list of dicts with keys ``"text"``, ``"index"`` (position in the
        returned list), ``"start_pos"`` and ``"end_pos"`` (half-open
        character span of the chunk within ``text``). Empty ``text`` gives
        an empty list.

    Raises:
        ValueError: If ``chunk_size``/``overlap`` are out of range. A zero
            step would crash ``range``, a negative step would silently
            return no chunks, and a negative overlap would skip text.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            "overlap must satisfy 0 <= overlap < chunk_size, "
            f"got overlap={overlap}, chunk_size={chunk_size}"
        )

    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(text), step):
        # Named `piece` so the local does not shadow this function's name.
        piece = text[start:start + chunk_size]
        chunks.append({
            "text": piece,
            "index": len(chunks),
            "start_pos": start,
            "end_pos": start + len(piece),
        })
        # This window already reaches the end of the text; any later window
        # would be a strict suffix of it and only duplicate content.
        if start + chunk_size >= len(text):
            break
    print(f"Created {len(chunks)} text chunks")
    return chunks
Creating Embeddings
def create_embeddings(texts, model="BAAI/bge-en-icl", batch_size=100):
    """Embed texts with the module-level ``client``, sending them in batches.

    Args:
        texts: A list of strings to embed. A single string is also accepted
            and treated as a one-element list.
        model: Embedding model name passed to ``client.embeddings.create``.
        batch_size: Maximum number of texts per API request; must be positive.

    Returns:
        A list of embedding vectors (``item.embedding`` from each response),
        one per input text. Output order follows the order of
        ``response.data``; this assumes the API returns items in request
        order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # A bare string must not reach the slicing below: slicing a str yields
    # substrings, so a long query would be embedded as several
    # 100-character fragments instead of one text.
    if isinstance(texts, str):
        texts = [texts]
    if not texts:
        return []
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    all_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        response = client.embeddings.create(model=model, input=batch)
        all_embeddings.extend(item.embedding for item in response.data)
    return all_embeddings
Knowledge Graph Construction
def extract_concepts(text):
    """Ask the chat model for the key concepts/entities in ``text``.

    Only the first 3000 characters of ``text`` are sent to the model.

    Returns:
        A list of stripped, non-empty concept strings (possibly empty).
        Non-string items are dropped so callers can safely put the result
        in a ``set``.
    """
    # response_format=json_object requires the model to emit a JSON *object*,
    # so ask for one with an explicit "concepts" key rather than a bare array.
    system_message = (
        "Extract key concepts and entities from the provided text.\n"
        "Return ONLY a list of 5-10 key terms, entities, or concepts that are most important in this text.\n"
        'Format your response as a JSON object of the form {"concepts": ["term 1", "term 2", ...]}.'
    )
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": f"Extract key concepts from:\n\n{text[:3000]}"},
        ],
        temperature=0.0,
        response_format={"type": "json_object"},
    )
    return _parse_concepts(response.choices[0].message.content)


def _parse_concepts(content):
    """Turn the model's raw reply into a clean list of concept strings."""
    if not content:
        return []
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None

    if isinstance(parsed, dict):
        candidates = parsed.get("concepts")
        if not isinstance(candidates, list):
            # Model used a different key: fall back to the first list value.
            candidates = next(
                (value for value in parsed.values() if isinstance(value, list)), []
            )
    elif isinstance(parsed, list):
        candidates = parsed
    else:
        # Not usable JSON: salvage quoted strings from the first [...] span.
        match = re.search(r'\[(.*?)\]', content, re.DOTALL)
        candidates = re.findall(r'"([^"]*)"', match.group(1)) if match else []

    return [item.strip() for item in candidates if isinstance(item, str) and item.strip()]
def build_knowledge_graph(chunks, edge_threshold=0.6, similarity_weight=0.7):
    """Build an undirected graph of chunks linked by shared concepts and similarity.

    Each chunk becomes node ``i`` (its index in ``chunks``) with attributes
    ``text``, ``concepts`` (from ``extract_concepts``) and ``embedding``.
    Two nodes are connected when they share at least one concept and

        weight = similarity_weight * cosine(e_i, e_j)
                 + (1 - similarity_weight) * |shared| / min(|C_i|, |C_j|)

    is greater than ``edge_threshold``. Concepts are compared after
    stripping and case-folding. Edges carry ``weight``, ``similarity`` and
    ``shared_concepts`` (the normalized shared terms, sorted).

    Args:
        chunks: Dicts with at least a ``"text"`` key, e.g. from ``chunk_text``.
        edge_threshold: Minimum combined weight for an edge (default 0.6).
        similarity_weight: Share of the weight given to cosine similarity;
            the remainder goes to concept overlap (default 0.7 / 0.3).

    Returns:
        ``(graph, embeddings)``: the ``networkx.Graph`` and the chunk
        embeddings returned by ``create_embeddings``.
    """
    print("Building knowledge graph...")
    graph = nx.Graph()
    texts = [chunk["text"] for chunk in chunks]
    print("Creating embeddings for chunks...")
    embeddings = create_embeddings(texts)

    print("Adding nodes to the graph...")
    concept_sets = []
    for i, chunk in enumerate(chunks):
        print(f"Extracting concepts for chunk {i+1}/{len(chunks)}...")
        concepts = extract_concepts(chunk["text"])
        graph.add_node(i,
                       text=chunk["text"],
                       concepts=concepts,
                       embedding=embeddings[i])
        # LLM output varies in case/spacing ("Machine Learning" vs
        # "machine learning"); normalize so such variants still link nodes.
        concept_sets.append(_normalize_concepts(concepts))

    print("Creating edges between nodes...")
    # Convert and measure each embedding once instead of once per pair.
    vectors = [np.asarray(embedding, dtype=float) for embedding in embeddings]
    norms = [np.linalg.norm(vector) for vector in vectors]
    concept_weight = 1 - similarity_weight
    for i in range(len(chunks)):
        for j in range(i + 1, len(chunks)):
            shared_concepts = concept_sets[i] & concept_sets[j]
            if not shared_concepts:
                continue
            similarity = np.dot(vectors[i], vectors[j]) / (norms[i] * norms[j])
            # Non-zero denominator: both sets contain the shared concept(s).
            concept_score = len(shared_concepts) / min(len(concept_sets[i]), len(concept_sets[j]))
            edge_weight = similarity_weight * similarity + concept_weight * concept_score
            if edge_weight > edge_threshold:
                graph.add_edge(i, j,
                               weight=edge_weight,
                               similarity=similarity,
                               shared_concepts=sorted(shared_concepts))

    print(f"Knowledge graph built with {graph.number_of_nodes()} nodes and {graph.number_of_edges()} edges")
    return graph, embeddings


def _normalize_concepts(concepts):
    """Return the set of non-blank string concepts, stripped and case-folded."""
    return {
        concept.strip().casefold()
        for concept in concepts
        if isinstance(concept, str) and concept.strip()
    }
Graph Traversal and Query Processing
def traverse_graph(query, graph, embeddings, top_k=5, max_depth=3):
    """Retrieve chunks by best-first expansion of the knowledge graph.

    The ``top_k`` nodes whose embeddings are most cosine-similar to the
    query seed a priority queue, prioritized by that similarity. Nodes are
    popped highest-priority first and recorded. Their unvisited neighbours
    are enqueued with the connecting edge's ``weight`` as priority. This
    continues until ``top_k * 3`` nodes are collected or the queue empties.
    Neighbours are expanded only from nodes fewer than ``max_depth`` hops
    away from a seed.

    Args:
        query: The user's question.
        graph: Graph from ``build_knowledge_graph``; node ids are assumed to
            be the indices into ``embeddings``.
        embeddings: Chunk embeddings, one per node.
        top_k: Number of seed nodes; results are capped at ``top_k * 3``.
        max_depth: Maximum number of hops from a seed node.

    Returns:
        ``(results, traversal_path)``: a list of
        ``{"text", "concepts", "node_id"}`` dicts in visit order, and the
        visited node ids in the same order.
    """
    print(f"Traversing graph for query: {query}")
    # Wrap the query in a list: create_embeddings slices its input into
    # batches, and slicing a bare string would embed it in fragments.
    query_vector = np.asarray(create_embeddings([query])[0], dtype=float)
    query_norm = np.linalg.norm(query_vector)

    # Keep similarity keyed by node id, independent of any ranking order,
    # so a seed's priority is its own score.
    node_similarity = {}
    for node_id, node_embedding in enumerate(embeddings):
        node_vector = np.asarray(node_embedding, dtype=float)
        node_similarity[node_id] = np.dot(query_vector, node_vector) / (
            query_norm * np.linalg.norm(node_vector)
        )
    # nlargest(n, it, key) == sorted(it, key=key, reverse=True)[:n], so ties
    # keep ascending node-id order.
    starting_nodes = heapq.nlargest(top_k, node_similarity, key=node_similarity.get)
    print(f"Starting traversal from {len(starting_nodes)} nodes")

    max_results = top_k * 3
    visited = set()
    traversal_path = []
    results = []
    # Heap entries: (negated priority, node id, hops from a seed).
    queue = [(-node_similarity[node], node, 0) for node in starting_nodes]
    heapq.heapify(queue)

    while queue and len(results) < max_results:
        _, node, depth = heapq.heappop(queue)
        if node in visited:
            continue
        visited.add(node)
        traversal_path.append(node)
        results.append({
            "text": graph.nodes[node]["text"],
            "concepts": graph.nodes[node]["concepts"],
            "node_id": node,
        })
        if depth < max_depth:
            # NOTE(review): neighbours are ranked by chunk-to-chunk edge
            # weight, while seeds are ranked by query similarity; the two
            # scores share one queue but are not the same measure.
            for neighbor in graph.neighbors(node):
                if neighbor not in visited:
                    heapq.heappush(
                        queue, (-graph[node][neighbor]["weight"], neighbor, depth + 1)
                    )

    print(f"Graph traversal found {len(results)} relevant chunks")
    return results, traversal_path
Response Generation
def generate_response(query, context_chunks):
    """Answer ``query`` with the chat model, grounded in the retrieved chunks.

    Chunk texts are joined with ``---`` separator lines and clipped to
    14000 characters (followed by an ``... [truncated]`` marker) before
    being sent together with the question.

    Returns:
        The model's reply text.
    """
    limit = 14000
    context = "\n\n---\n\n".join(chunk["text"] for chunk in context_chunks)
    if len(context) > limit:
        context = f"{context[:limit]}... [truncated]"

    instructions = (
        "You are a helpful AI assistant. Answer the user's question based on the provided context.\n"
        "If the information is not in the context, say so. Refer to specific parts of the context in your answer when possible."
    )
    chat_messages = [
        {"role": "system", "content": instructions},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {query}"},
    ]
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=chat_messages,
        temperature=0.2,
    )
    return completion.choices[0].message.content
Visualization
def visualize_graph_traversal(graph, traversal_path):
    """Draw the knowledge graph and highlight the nodes a traversal visited.

    Node colours: unvisited light blue, visited light green, first visited
    green, last visited red. Edge widths are twice each edge's ``weight``
    (default 1.0). Consecutive visited nodes that are directly connected
    are overlaid with dashed red edges. Labels read ``"<id>: <first
    concept>"``, or ``"<id>: Node <id>"`` when a node has no concepts. The
    figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(12, 10))

    # Colour by node id, not list position, so any hashable ids work.
    colors = {node: 'lightblue' for node in graph.nodes()}
    for node in traversal_path:
        colors[node] = 'lightgreen'
    if traversal_path:
        colors[traversal_path[0]] = 'green'
        colors[traversal_path[-1]] = 'red'
    nodelist = list(graph.nodes())

    pos = nx.spring_layout(graph, k=0.5, iterations=50, seed=42)
    nx.draw_networkx_nodes(graph, pos, nodelist=nodelist,
                           node_color=[colors[node] for node in nodelist],
                           node_size=500, alpha=0.8)

    # One draw call for all edges, with per-edge widths.
    weighted_edges = list(graph.edges(data=True))
    if weighted_edges:
        nx.draw_networkx_edges(graph, pos,
                               edgelist=[(u, v) for u, v, _ in weighted_edges],
                               width=[data.get('weight', 1.0) * 2 for _, _, data in weighted_edges],
                               alpha=0.6)

    # Visit order can jump between unconnected nodes (e.g. two seeds); only
    # highlight steps that correspond to real graph edges.
    traversal_edges = [(a, b) for a, b in zip(traversal_path, traversal_path[1:])
                       if graph.has_edge(a, b)]
    if traversal_edges:
        nx.draw_networkx_edges(graph, pos, edgelist=traversal_edges,
                               width=3, alpha=0.8, edge_color='red',
                               style='dashed', arrows=True)

    labels = {}
    for node in graph.nodes():
        concepts = graph.nodes[node]['concepts']
        label = concepts[0] if concepts else f"Node {node}"
        labels[node] = f"{node}: {label}"
    nx.draw_networkx_labels(graph, pos, labels=labels, font_size=8)

    plt.title("Knowledge Graph with Traversal Path")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
Complete Graph RAG Pipeline
def graph_rag_pipeline(pdf_path, query, chunk_size=1000, chunk_overlap=200, top_k=3):
    """Run the whole Graph RAG flow on one PDF for one query.

    Extracts the PDF text, chunks it, builds the knowledge graph, traverses
    it for ``query`` seeded from ``top_k`` nodes, generates an answer, and
    draws the traversal.

    Returns:
        A dict with keys ``"query"``, ``"response"``, ``"relevant_chunks"``,
        ``"traversal_path"`` and ``"graph"``.
    """
    document_text = extract_text_from_pdf(pdf_path)
    graph, embeddings = build_knowledge_graph(
        chunk_text(document_text, chunk_size, chunk_overlap)
    )
    relevant_chunks, traversal_path = traverse_graph(query, graph, embeddings, top_k)
    answer = generate_response(query, relevant_chunks)
    visualize_graph_traversal(graph, traversal_path)
    return dict(
        query=query,
        response=answer,
        relevant_chunks=relevant_chunks,
        traversal_path=traversal_path,
        graph=graph,
    )
Evaluation Function
def evaluate_graph_rag(pdf_path, test_queries, reference_answers=None,
                       chunk_size=1000, chunk_overlap=200):
    """Build one knowledge graph from a PDF and answer each test query with it.

    Args:
        pdf_path: PDF to index.
        test_queries: Questions to answer.
        reference_answers: Optional answers aligned by position with
            ``test_queries``. When present, each response is compared
            against its reference with ``compare_with_reference``.
        chunk_size: Characters per chunk passed to ``chunk_text``.
        chunk_overlap: Overlap between chunks passed to ``chunk_text``.

    Returns:
        ``{"results": [...], "graph_stats": {"nodes", "edges", "avg_degree"}}``.
        Each result holds the query, response, reference answer, comparison,
        traversal path length and number of retrieved chunks.
    """
    text = extract_text_from_pdf(pdf_path)
    chunks = chunk_text(text, chunk_size, chunk_overlap)
    graph, embeddings = build_knowledge_graph(chunks)

    results = []
    for i, query in enumerate(test_queries):
        print(f"\n\n=== Evaluating Query {i+1}/{len(test_queries)} ===")
        print(f"Query: {query}")
        relevant_chunks, traversal_path = traverse_graph(query, graph, embeddings)
        response = generate_response(query, relevant_chunks)

        reference = None
        comparison = None
        if reference_answers and i < len(reference_answers):
            reference = reference_answers[i]
            comparison = compare_with_reference(response, reference, query)

        results.append({
            "query": query,
            "response": response,
            "reference_answer": reference,
            "comparison": comparison,
            "traversal_path_length": len(traversal_path),
            "relevant_chunks_count": len(relevant_chunks)
        })
        print(f"\nResponse: {response}\n")
        if comparison:
            print(f"Comparison: {comparison}\n")

    node_count = graph.number_of_nodes()
    total_degree = sum(degree for _, degree in graph.degree())
    return {
        "results": results,
        "graph_stats": {
            "nodes": node_count,
            "edges": graph.number_of_edges(),
            # A document with no extractable text yields an empty graph;
            # report 0.0 instead of raising ZeroDivisionError.
            "avg_degree": total_degree / node_count if node_count else 0.0
        }
    }
def compare_with_reference(response, reference, query):
    """Have the chat model judge a generated answer against a reference answer.

    Returns:
        The model's brief written analysis (requested as 2-3 sentences) of
        correctness, completeness and relevance.
    """
    judge_instructions = (
        "Compare the AI-generated response with the reference answer.\n"
        "Evaluate based on: correctness, completeness, and relevance to the query.\n"
        "Provide a brief analysis (2-3 sentences) of how well the generated response matches the reference."
    )
    # Leading and trailing empty entries reproduce the surrounding newlines.
    prompt = "\n".join([
        "",
        f"Query: {query}",
        "AI-generated response:",
        f"{response}",
        "Reference answer:",
        f"{reference}",
        "How well does the AI response match the reference?",
        "",
    ])
    judgement = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": judge_instructions},
            {"role": "user", "content": prompt},
        ],
        temperature=0.0,
    )
    return judgement.choices[0].message.content
Evaluation of Graph RAG on a Sample PDF Document
# Demo run: index a sample PDF and evaluate one query against a reference answer.
pdf_path = "data/AI_Information.pdf"
test_queries = [
    "What are the key applications of transformers in natural language processing?",
]
reference_answers = [
    (
        "Transformer models have revolutionized natural language processing "
        "with applications including machine translation, text summarization, "
        "question answering, sentiment analysis, and text generation. "
        "They excel at capturing long-range dependencies in text and have "
        "become the foundation for models like BERT, GPT, and T5."
    ),
]
evaluation_results = evaluate_graph_rag(pdf_path, test_queries, reference_answers)