June 27, 2025
35 min read
Multi-Modal RAG with Image Captioning
RAG
NLP
AI
Multi-modal
Image Captioning
In this notebook, I implement a Multi-Modal RAG system that extracts both text and images from documents, generates captions for images, and uses both content types to respond to queries. This approach enhances traditional RAG by incorporating visual information into the knowledge base.
Traditional RAG systems work only with text, but many documents carry crucial information in images, charts, and tables. By captioning these visual elements and adding the captions to the retrieval index, we can:
- Access information locked in figures and diagrams
- Understand tables and charts that complement the text
- Create a more comprehensive knowledge base
- Answer questions that rely on visual data
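Concretely, the pipeline built in this notebook has six stages; the outline below is just a map of the functions defined in the sections that follow:
# Map of the pipeline implemented below:
#   extract_content_from_pdf -> page text + saved image files
#   chunk_text               -> overlapping text chunks with page metadata
#   process_images           -> one caption per extracted image
#   create_embeddings        -> a vector for every chunk and caption
#   MultiModalVectorStore    -> cosine-similarity retrieval over both content types
#   query_multimodal_rag     -> retrieve mixed context and generate the final answer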
Setting Up the Environment
We begin by importing the necessary libraries: PyMuPDF (imported as fitz) for PDF parsing, Pillow for image handling, and the OpenAI client for captioning, embeddings, and response generation.
import os
import io
import numpy as np
import json
import fitz  # PyMuPDF
from PIL import Image
from openai import OpenAI
import base64
import re
import tempfile
import shutil
Setting Up the OpenAI API Client
We initialize the OpenAI client, pointed here at the Nebius AI Studio endpoint, to generate image captions, embeddings, and responses.
client = OpenAI(
base_url="https://api.studio.nebius.com/v1/",
api_key=os.getenv("OPENAI_API_KEY")
)
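As a quick sanity check, you can request a single embedding before processing a whole document (this assumes the API key is set in the OPENAI_API_KEY environment variable and the endpoint serves the embedding model used later in this notebook):
# Optional sanity check: confirm the endpoint and key work before running the full pipeline
probe = client.embeddings.create(model="BAAI/bge-en-icl", input=["hello world"])
print(len(probe.data[0].embedding))  # prints the embedding dimensionality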
Document Processing Functions
def extract_content_from_pdf(pdf_path, output_dir=None):
    """Extract page text and embedded images from a PDF; images are written to output_dir."""
temp_dir = None
if output_dir is None:
temp_dir = tempfile.mkdtemp()
output_dir = temp_dir
else:
os.makedirs(output_dir, exist_ok=True)
text_data = []
image_paths = []
print(f"Extracting content from {pdf_path}...")
try:
with fitz.open(pdf_path) as pdf_file:
for page_number in range(len(pdf_file)):
page = pdf_file[page_number]
text = page.get_text().strip()
if text:
text_data.append({
"content": text,
"metadata": {
"source": pdf_path,
"page": page_number + 1,
"type": "text"
}
})
image_list = page.get_images(full=True)
for img_index, img in enumerate(image_list):
xref = img[0]
base_image = pdf_file.extract_image(xref)
if base_image:
image_bytes = base_image["image"]
image_ext = base_image["ext"]
img_filename = f"page_{page_number+1}_img_{img_index+1}.{image_ext}"
img_path = os.path.join(output_dir, img_filename)
with open(img_path, "wb") as img_file:
img_file.write(image_bytes)
image_paths.append({
"path": img_path,
"metadata": {
"source": pdf_path,
"page": page_number + 1,
"image_index": img_index + 1,
"type": "image"
}
})
print(f"Extracted {len(text_data)} text segments and {len(image_paths)} images")
return text_data, image_paths
except Exception as e:
print(f"Error extracting content: {e}")
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
raise
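A minimal usage sketch, assuming the paper used later in this notebook sits at data/attention_is_all_you_need.pdf:
text_data, image_paths = extract_content_from_pdf("data/attention_is_all_you_need.pdf", "extracted_images")
print(text_data[0]["metadata"])   # e.g. {'source': '...', 'page': 1, 'type': 'text'}
print(image_paths[0]["path"])     # e.g. extracted_images/page_3_img_1.png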
Chunking Text Content
def chunk_text(text_data, chunk_size=1000, overlap=200):
    """Split extracted page text into overlapping character windows, preserving page metadata."""
chunked_data = []
for item in text_data:
text = item["content"]
metadata = item["metadata"]
if len(text) < chunk_size / 2:
chunked_data.append({
"content": text,
"metadata": metadata
})
continue
chunks = []
for i in range(0, len(text), chunk_size - overlap):
chunk = text[i:i + chunk_size]
if chunk:
chunks.append(chunk)
for i, chunk in enumerate(chunks):
chunk_metadata = metadata.copy()
chunk_metadata["chunk_index"] = i
chunk_metadata["chunk_count"] = len(chunks)
chunked_data.append({
"content": chunk,
"metadata": chunk_metadata
})
print(f"Created {len(chunked_data)} text chunks")
return chunked_data
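To see the sliding window in action, here is a toy example: a 2,500-character string chunked with the default window settings.
sample = [{"content": "a" * 2500, "metadata": {"source": "demo", "page": 1, "type": "text"}}]
demo_chunks = chunk_text(sample, chunk_size=1000, overlap=200)
# Windows start every chunk_size - overlap = 800 characters
print([len(c["content"]) for c in demo_chunks])  # [1000, 1000, 900, 100]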
Image Captioning with OpenAI Vision
def encode_image(image_path):
with open(image_path, "rb") as image_file:
encoded_image = base64.b64encode(image_file.read())
return encoded_image.decode('utf-8')
def generate_image_caption(image_path):
    """Caption an extracted image with a vision model so it can be indexed alongside text."""
    if not os.path.exists(image_path):
        return "Error: Image file not found"
    try:
        # Confirm the file is a readable image before sending it to the API
        with Image.open(image_path) as img:
            img.verify()
base64_image = encode_image(image_path)
response = client.chat.completions.create(
model="llava-hf/llava-1.5-7b-hf",
messages=[
{
"role": "system",
"content": "You are an assistant specialized in describing images from academic papers. "
"Provide detailed captions for the image that capture key information. "
"If the image contains charts, tables, or diagrams, describe their content and purpose clearly. "
"Your caption should be optimized for future retrieval when people ask questions about this content."
},
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image in detail, focusing on its academic content:"},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
}
}
]
},
],
max_tokens=300
)
caption = response.choices[0].message.content
return caption
except Exception as e:
return f"Error generating caption: {str(e)}"
def process_images(image_paths):
image_data = []
print(f"Generating captions for {len(image_paths)} images...")
for i, img_item in enumerate(image_paths):
print(f"Processing image {i+1}/{len(image_paths)}...")
img_path = img_item["path"]
metadata = img_item["metadata"]
caption = generate_image_caption(img_path)
image_data.append({
"content": caption,
"metadata": metadata,
"image_path": img_path
})
return image_data
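For example, captioning one of the images saved during extraction (the filename below is hypothetical; use whichever path extract_content_from_pdf actually produced):
caption = generate_image_caption("extracted_images/page_3_img_1.png")  # hypothetical filename
print(caption[:300])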
Simple Vector Store Implementation
class MultiModalVectorStore:
def __init__(self):
self.vectors = []
self.contents = []
self.metadata = []
def add_item(self, content, embedding, metadata=None):
self.vectors.append(np.array(embedding))
self.contents.append(content)
self.metadata.append(metadata or {})
def add_items(self, items, embeddings):
for item, embedding in zip(items, embeddings):
self.add_item(
content=item["content"],
embedding=embedding,
metadata=item.get("metadata", {})
)
def similarity_search(self, query_embedding, k=5):
if not self.vectors:
return []
query_vector = np.array(query_embedding)
similarities = []
for i, vector in enumerate(self.vectors):
similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
similarities.append((i, similarity))
similarities.sort(key=lambda x: x[1], reverse=True)
results = []
for i in range(min(k, len(similarities))):
idx, score = similarities[i]
results.append({
"content": self.contents[idx],
"metadata": self.metadata[idx],
"similarity": float(score)
})
return results
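A tiny self-contained check of the cosine-similarity search, using toy 3-dimensional vectors rather than real embeddings:
toy_store = MultiModalVectorStore()
toy_store.add_item("about cats", [1.0, 0.0, 0.0], {"type": "text"})
toy_store.add_item("about dogs", [0.0, 1.0, 0.0], {"type": "text"})
# The query vector points mostly along the "cats" axis, so that item ranks first
print(toy_store.similarity_search([0.9, 0.1, 0.0], k=1)[0]["content"])  # about cats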
Creating Embeddings
def create_embeddings(texts, model="BAAI/bge-en-icl"):
    """Embed a list of strings in batches of 100; returns one vector per input string."""
if not texts:
return []
batch_size = 100
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
response = client.embeddings.create(
model=model,
input=batch
)
batch_embeddings = [item.embedding for item in response.data]
all_embeddings.extend(batch_embeddings)
return all_embeddings
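Note that create_embeddings takes a list of strings and returns one vector per string, in order; this is why the query path later wraps a single query string in a list. A quick check:
vecs = create_embeddings(["first text", "second text"])
print(len(vecs), len(vecs[0]))  # 2 vectors, each of the embedding model's dimensionality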
Complete Processing Pipeline
def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):
    """Full ingestion pipeline: extract, chunk, caption, embed, and index a PDF."""
image_dir = "extracted_images"
os.makedirs(image_dir, exist_ok=True)
text_data, image_paths = extract_content_from_pdf(pdf_path, image_dir)
chunked_text = chunk_text(text_data, chunk_size, chunk_overlap)
image_data = process_images(image_paths)
all_items = chunked_text + image_data
contents = [item["content"] for item in all_items]
print("Creating embeddings for all content...")
embeddings = create_embeddings(contents)
vector_store = MultiModalVectorStore()
vector_store.add_items(all_items, embeddings)
doc_info = {
"text_count": len(chunked_text),
"image_count": len(image_data),
"total_items": len(all_items),
}
print(f"Added {len(all_items)} items to vector store ({len(chunked_text)} text chunks, {len(image_data)} image captions)")
return vector_store, doc_info
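Running the pipeline on the paper used in the evaluation below (this makes one captioning call per image plus the embedding calls):
vector_store, doc_info = process_document("data/attention_is_all_you_need.pdf")
print(doc_info)  # {'text_count': ..., 'image_count': ..., 'total_items': ...}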
Query Processing and Response Generation
def query_multimodal_rag(query, vector_store, k=5):
    """Retrieve the k most similar text chunks and image captions, then generate an answer."""
    print(f"\n=== Processing query: {query} ===\n")
    # create_embeddings expects a list, so wrap the query and take the single resulting vector
    query_embedding = create_embeddings([query])[0]
results = vector_store.similarity_search(query_embedding, k=k)
text_results = [r for r in results if r["metadata"].get("type") == "text"]
image_results = [r for r in results if r["metadata"].get("type") == "image"]
print(f"Retrieved {len(results)} relevant items ({len(text_results)} text, {len(image_results)} image captions)")
response = generate_response(query, results)
return {
"query": query,
"results": results,
"response": response,
"text_results_count": len(text_results),
"image_results_count": len(image_results)
}
def generate_response(query, results):
context = ""
for i, result in enumerate(results):
content_type = "Text" if result["metadata"].get("type") == "text" else "Image caption"
page_num = result["metadata"].get("page", "unknown")
context += f"[{content_type} from page {page_num}]\n"
context += result["content"]
context += "\n\n"
system_message = """You are an AI assistant specializing in answering questions about documents
that contain both text and images. You have been given relevant text passages and image captions
from the document. Use this information to provide a comprehensive, accurate response to the query.
If information comes from an image or chart, mention this in your answer.
If the retrieved information doesn't fully answer the query, acknowledge the limitations."""
user_message = f"""Query: {query}
Retrieved content:
{context}
Please answer the query based on the retrieved content.
"""
response = client.chat.completions.create(
model="meta-llama/Llama-3.2-3B-Instruct",
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": user_message}
],
temperature=0.1
)
return response.choices[0].message.content
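With the vector store from the process_document example above, a single query looks like this:
result = query_multimodal_rag("What is the BLEU score of the Transformer (base model)?", vector_store)
print(result["response"])
print(f"Used {result['text_results_count']} text chunks and {result['image_results_count']} image captions")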
Evaluation Against Text-Only RAG
def build_text_only_store(pdf_path, chunk_size=1000, chunk_overlap=200):
    # Images are still extracted (to a temporary directory) but ignored for the text-only baseline
    text_data, _ = extract_content_from_pdf(pdf_path, None)
chunked_text = chunk_text(text_data, chunk_size, chunk_overlap)
contents = [item["content"] for item in chunked_text]
print("Creating embeddings for text-only content...")
embeddings = create_embeddings(contents)
vector_store = MultiModalVectorStore()
vector_store.add_items(chunked_text, embeddings)
print(f"Added {len(chunked_text)} text items to text-only vector store")
return vector_store
def evaluate_multimodal_vs_textonly(pdf_path, test_queries, reference_answers=None):
print("=== EVALUATING MULTI-MODAL RAG VS TEXT-ONLY RAG ===\n")
multimodal_vector_store, _ = process_document(pdf_path)
text_only_vector_store = build_text_only_store(pdf_path)
results = []
for i, query in enumerate(test_queries):
print(f"\n=== Evaluating Query {i+1}: {query} ===\n")
print("Running multi-modal RAG...")
multimodal_result = query_multimodal_rag(query, multimodal_vector_store)
print("\nRunning text-only RAG...")
text_only_result = query_multimodal_rag(query, text_only_vector_store)
result = {
"query": query,
"multimodal_result": multimodal_result,
"text_only_result": text_only_result,
}
if reference_answers and i < len(reference_answers):
result["reference_answer"] = reference_answers[i]
results.append(result)
overall_analysis = generate_overall_analysis(results)
return {
"results": results,
"overall_analysis": overall_analysis
}
def generate_overall_analysis(results):
    system_prompt = """You are an expert evaluator of retrieval-augmented generation systems. Compare the performance of multi-modal RAG (text + image captions) and text-only RAG across the test queries below. Focus on the types of queries where multi-modal RAG outperforms text-only RAG, the value added by incorporating image information, and the limitations of the multi-modal approach."""
user_prompt = """Based on the following results, provide an overall analysis comparing Multi-Modal RAG and Text-Only RAG:
"""
for i, result in enumerate(results):
user_prompt += f"Query {i+1}: {result['query']}\n"
user_prompt += f"Multi-modal Response: {result['multimodal_result']['response']}\n"
user_prompt += f"Text-only Response: {result['text_only_result']['response']}\n"
if 'reference_answer' in result:
user_prompt += f"Reference Answer: {result['reference_answer']}\n"
user_prompt += "\n"
response = client.chat.completions.create(
model="meta-llama/Llama-3.2-3B-Instruct",
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
temperature=0.1
)
return response.choices[0].message.content
Evaluating the Multi-Modal RAG System
pdf_path = "data/attention_is_all_you_need.pdf"
test_queries = [
"What is the BLEU score of the Transformer (base model)?",
]
reference_answers = [
"The Transformer (base model) achieves a BLEU score of 27.3 on the WMT 2014 English-to-German translation task and 38.1 on the WMT 2014 English-to-French translation task.",
]
evaluation_results = evaluate_multimodal_vs_textonly(
pdf_path=pdf_path,
test_queries=test_queries,
reference_answers=reference_answers
)
print("\n=== OVERALL ANALYSIS ===\n")
print(evaluation_results["overall_analysis"])