Oreoluwa
Multi-Modal RAG with Image Captioning
June 27, 2025
35 min read

RAG
NLP
AI
Multi-modal
Image Captioning

In this notebook, I implement a Multi-Modal RAG system that extracts both text and images from documents, generates captions for images, and uses both content types to respond to queries. This approach enhances traditional RAG by incorporating visual information into the knowledge base.

Traditional RAG systems only work with text, but many documents contain crucial information in images, charts, and tables. By captioning these visual elements and incorporating them into our retrieval system, we can:

  • Access information locked in figures and diagrams
  • Understand tables and charts that complement the text
  • Create a more comprehensive knowledge base
  • Answer questions that rely on visual data

Setting Up the Environment

We begin by importing the necessary libraries; fitz is the import name for PyMuPDF, which we use to read PDFs and extract their embedded images.

import os
import io
import numpy as np
import json
import fitz  # PyMuPDF
from PIL import Image
from openai import OpenAI
import base64
import mimetypes
import re
import tempfile
import shutil

Setting Up the OpenAI API Client

We initialize an OpenAI-compatible client, pointed at the Nebius AI Studio endpoint, to generate embeddings, image captions, and chat responses. The API key is read from the OPENAI_API_KEY environment variable.

client = OpenAI(
    base_url="https://api.studio.nebius.com/v1/",
    api_key=os.getenv("OPENAI_API_KEY")
)
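
Before building anything on top of the client, it helps to confirm that the endpoint and key actually work. A minimal sanity-check sketch, assuming OPENAI_API_KEY holds a valid key for the configured endpoint:

# Quick ping of the chat model used later in this notebook
ping = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Reply with the single word: ready"}],
    max_tokens=5
)
print(ping.choices[0].message.content)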

Document Processing Functions

The function below walks each page of the PDF with PyMuPDF, collecting the page text and saving every embedded image to disk, with page-level metadata attached to both. If extraction fails, any temporary image directory it created is cleaned up.

def extract_content_from_pdf(pdf_path, output_dir=None):
    temp_dir = None
    if output_dir is None:
        temp_dir = tempfile.mkdtemp()
        output_dir = temp_dir
    else:
        os.makedirs(output_dir, exist_ok=True)
    text_data = []
    image_paths = []
    print(f"Extracting content from {pdf_path}...")
    try:
        with fitz.open(pdf_path) as pdf_file:
            for page_number in range(len(pdf_file)):
                page = pdf_file[page_number]
                text = page.get_text().strip()
                if text:
                    text_data.append({
                        "content": text,
                        "metadata": {
                            "source": pdf_path,
                            "page": page_number + 1,
                            "type": "text"
                        }
                    })
                image_list = page.get_images(full=True)
                for img_index, img in enumerate(image_list):
                    xref = img[0]
                    base_image = pdf_file.extract_image(xref)
                    if base_image:
                        image_bytes = base_image["image"]
                        image_ext = base_image["ext"]
                        img_filename = f"page_{page_number+1}_img_{img_index+1}.{image_ext}"
                        img_path = os.path.join(output_dir, img_filename)
                        with open(img_path, "wb") as img_file:
                            img_file.write(image_bytes)
                        image_paths.append({
                            "path": img_path,
                            "metadata": {
                                "source": pdf_path,
                                "page": page_number + 1,
                                "image_index": img_index + 1,
                                "type": "image"
                            }
                        })
        print(f"Extracted {len(text_data)} text segments and {len(image_paths)} images")
        return text_data, image_paths
    except Exception as e:
        print(f"Error extracting content: {e}")
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir)
        raise
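
A quick usage sketch, assuming the paper PDF used in the evaluation section below is available at data/attention_is_all_you_need.pdf:

text_data, image_paths = extract_content_from_pdf(
    "data/attention_is_all_you_need.pdf", "extracted_images"
)
print(text_data[0]["metadata"])    # e.g. {'source': ..., 'page': 1, 'type': 'text'}
print(len(image_paths), "images written to extracted_images/")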

Chunking Text Content

Text segments are split into overlapping chunks so that context is not lost at chunk boundaries; segments shorter than half the chunk size are kept whole.

def chunk_text(text_data, chunk_size=1000, overlap=200):
    chunked_data = []
    for item in text_data:
        text = item["content"]
        metadata = item["metadata"]
        if len(text) < chunk_size / 2:
            chunked_data.append({
                "content": text,
                "metadata": metadata
            })
            continue
        chunks = []
        for i in range(0, len(text), chunk_size - overlap):
            chunk = text[i:i + chunk_size]
            if chunk:
                chunks.append(chunk)
        for i, chunk in enumerate(chunks):
            chunk_metadata = metadata.copy()
            chunk_metadata["chunk_index"] = i
            chunk_metadata["chunk_count"] = len(chunks)
            chunked_data.append({
                "content": chunk,
                "metadata": chunk_metadata
            })
    print(f"Created {len(chunked_data)} text chunks")
    return chunked_data
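
To see how the overlap behaves in isolation, here is a toy run with deliberately small parameters (the values are for illustration only):

sample = [{
    "content": "A" * 250,  # a 250-character dummy page
    "metadata": {"source": "demo.pdf", "page": 1, "type": "text"}
}]
chunks = chunk_text(sample, chunk_size=100, overlap=20)
# Chunks start every chunk_size - overlap = 80 characters: 0, 80, 160, 240
print([len(c["content"]) for c in chunks])  # [100, 100, 90, 10]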

Image Captioning with a Vision-Language Model

Each extracted image is base64-encoded and sent to a vision-language model (llava-hf/llava-1.5-7b-hf, served through the same OpenAI-compatible endpoint), which is prompted to write captions that will be useful for later retrieval.

def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read())
        return encoded_image.decode('utf-8')

def generate_image_caption(image_path):
    if not os.path.exists(image_path):
        return "Error: Image file not found"
    try:
        # Confirm the file is a readable image before calling the API
        with Image.open(image_path) as img:
            img.verify()
        base64_image = encode_image(image_path)
        # Build a data URL with the correct MIME type for the image format
        mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"
        response = client.chat.completions.create(
            model="llava-hf/llava-1.5-7b-hf",
            messages=[
                {
                    "role": "system",
                    "content": "You are an assistant specialized in describing images from academic papers. "
                               "Provide detailed captions for the image that capture key information. "
                               "If the image contains charts, tables, or diagrams, describe their content and purpose clearly. "
                               "Your caption should be optimized for future retrieval when people ask questions about this content."
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe this image in detail, focusing on its academic content:"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}"
                            }
                        }
                    ]
                },
            ],
            max_tokens=300
        )
        caption = response.choices[0].message.content
        return caption
    except Exception as e:
        return f"Error generating caption: {str(e)}"

def process_images(image_paths):
    image_data = []
    print(f"Generating captions for {len(image_paths)} images...")
    for i, img_item in enumerate(image_paths):
        print(f"Processing image {i+1}/{len(image_paths)}...")
        img_path = img_item["path"]
        metadata = img_item["metadata"]
        caption = generate_image_caption(img_path)
        image_data.append({
            "content": caption,
            "metadata": metadata,
            "image_path": img_path
        })
    return image_data
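
Before captioning every image, it is worth spot-checking a single caption (assuming image_paths comes from the extraction sketch above and is non-empty):

if image_paths:
    sample_image = image_paths[0]
    print(f"Captioning {sample_image['path']} (page {sample_image['metadata']['page']})")
    print(generate_image_caption(sample_image["path"]))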

Simple Vector Store Implementation

Instead of an external vector database, we use a small in-memory store that keeps embeddings as NumPy arrays and ranks items by cosine similarity.

class MultiModalVectorStore:
    def __init__(self):
        self.vectors = []
        self.contents = []
        self.metadata = []
    
    def add_item(self, content, embedding, metadata=None):
        self.vectors.append(np.array(embedding))
        self.contents.append(content)
        self.metadata.append(metadata or {})
    
    def add_items(self, items, embeddings):
        for item, embedding in zip(items, embeddings):
            self.add_item(
                content=item["content"],
                embedding=embedding,
                metadata=item.get("metadata", {})
            )
    
    def similarity_search(self, query_embedding, k=5):
        if not self.vectors:
            return []
        query_vector = np.array(query_embedding)
        similarities = []
        for i, vector in enumerate(self.vectors):
            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
            similarities.append((i, similarity))
        similarities.sort(key=lambda x: x[1], reverse=True)
        results = []
        for i in range(min(k, len(similarities))):
            idx, score = similarities[i]
            results.append({
                "content": self.contents[idx],
                "metadata": self.metadata[idx],
                "similarity": float(score)
            })
        return results
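
A toy example of the store's behaviour with hand-made 3-dimensional vectors (real embeddings from the model used below are much higher-dimensional):

store = MultiModalVectorStore()
store.add_item("text about cats", [1.0, 0.0, 0.0], {"type": "text"})
store.add_item("text about dogs", [0.0, 1.0, 0.0], {"type": "text"})
store.add_item("caption of a cat-weight chart", [0.9, 0.1, 0.0], {"type": "image"})

for hit in store.similarity_search([1.0, 0.0, 0.0], k=2):
    print(hit["content"], round(hit["similarity"], 3))
# text about cats 1.0
# caption of a cat-weight chart 0.994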

Creating Embeddings

Embeddings are requested in batches of 100 items. The helper also accepts a single string, so the same function can be reused later to embed queries.

def create_embeddings(texts, model="BAAI/bge-en-icl"):
    # Accept a single string (e.g. a query) as well as a list of texts
    if isinstance(texts, str):
        return create_embeddings([texts], model=model)[0]
    if not texts:
        return []
    batch_size = 100  # embed in batches to stay within API limits
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        response = client.embeddings.create(
            model=model,
            input=batch
        )
        batch_embeddings = [item.embedding for item in response.data]
        all_embeddings.extend(batch_embeddings)
    return all_embeddings
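
A short usage sketch (the embedding dimensionality depends on the model served at the endpoint):

vectors = create_embeddings(["first text chunk", "second text chunk"])
print(len(vectors), "embeddings of length", len(vectors[0]))
query_vector = create_embeddings("a single query string")
print("query embedding length:", len(query_vector))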

Complete Processing Pipeline

process_document ties everything together: extract text and images, chunk the text, caption the images, embed all content, and load it into the vector store.

def process_document(pdf_path, chunk_size=1000, chunk_overlap=200):
    image_dir = "extracted_images"
    os.makedirs(image_dir, exist_ok=True)
    text_data, image_paths = extract_content_from_pdf(pdf_path, image_dir)
    chunked_text = chunk_text(text_data, chunk_size, chunk_overlap)
    image_data = process_images(image_paths)
    all_items = chunked_text + image_data
    contents = [item["content"] for item in all_items]
    print("Creating embeddings for all content...")
    embeddings = create_embeddings(contents)
    vector_store = MultiModalVectorStore()
    vector_store.add_items(all_items, embeddings)
    doc_info = {
        "text_count": len(chunked_text),
        "image_count": len(image_data),
        "total_items": len(all_items),
    }
    print(f"Added {len(all_items)} items to vector store ({len(chunked_text)} text chunks, {len(image_data)} image captions)")
    return vector_store, doc_info
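
Running the full pipeline on the evaluation paper looks like this (a sketch; note that it makes one captioning call per image plus batched embedding calls for all content):

vector_store, doc_info = process_document("data/attention_is_all_you_need.pdf")
print(doc_info)  # {'text_count': ..., 'image_count': ..., 'total_items': ...}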

Query Processing and Response Generation

At query time we embed the query, retrieve the top-k most similar items (text chunks and image captions alike), and pass them as context to a chat model that composes the final answer.

def query_multimodal_rag(query, vector_store, k=5):
    print(f"\n=== Processing query: {query} ===\n")
    query_embedding = create_embeddings(query)
    results = vector_store.similarity_search(query_embedding, k=k)
    text_results = [r for r in results if r["metadata"].get("type") == "text"]
    image_results = [r for r in results if r["metadata"].get("type") == "image"]
    print(f"Retrieved {len(results)} relevant items ({len(text_results)} text, {len(image_results)} image captions)")
    response = generate_response(query, results)
    return {
        "query": query,
        "results": results,
        "response": response,
        "text_results_count": len(text_results),
        "image_results_count": len(image_results)
    }

def generate_response(query, results):
    context = ""
    for i, result in enumerate(results):
        content_type = "Text" if result["metadata"].get("type") == "text" else "Image caption"
        page_num = result["metadata"].get("page", "unknown")
        context += f"[{content_type} from page {page_num}]\n"
        context += result["content"]
        context += "\n\n"
    system_message = """You are an AI assistant specializing in answering questions about documents 
    that contain both text and images. You have been given relevant text passages and image captions 
    from the document. Use this information to provide a comprehensive, accurate response to the query.
    If information comes from an image or chart, mention this in your answer.
    If the retrieved information doesn't fully answer the query, acknowledge the limitations."""
    user_message = f"""Query: {query}

    Retrieved content:
    {context}

    Please answer the query based on the retrieved content.
    """
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message}
        ],
        temperature=0.1
    )
    return response.choices[0].message.content
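
End-to-end usage, reusing the vector store built in the pipeline sketch above and the same question as the evaluation section:

result = query_multimodal_rag(
    "What is the BLEU score of the Transformer (base model)?",
    vector_store,
    k=5
)
print(result["response"])
print(f"({result['text_results_count']} text hits, {result['image_results_count']} image-caption hits)")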

Evaluation Against Text-Only RAG

To measure what the image captions contribute, we build a second vector store from the text chunks alone and run the same queries against both stores.

def build_text_only_store(pdf_path, chunk_size=1000, chunk_overlap=200):
    text_data, _ = extract_content_from_pdf(pdf_path, None)
    chunked_text = chunk_text(text_data, chunk_size, chunk_overlap)
    contents = [item["content"] for item in chunked_text]
    print("Creating embeddings for text-only content...")
    embeddings = create_embeddings(contents)
    vector_store = MultiModalVectorStore()
    vector_store.add_items(chunked_text, embeddings)
    print(f"Added {len(chunked_text)} text items to text-only vector store")
    return vector_store

def evaluate_multimodal_vs_textonly(pdf_path, test_queries, reference_answers=None):
    print("=== EVALUATING MULTI-MODAL RAG VS TEXT-ONLY RAG ===\n")
    multimodal_vector_store, _ = process_document(pdf_path)
    text_only_vector_store = build_text_only_store(pdf_path)
    results = []
    for i, query in enumerate(test_queries):
        print(f"\n=== Evaluating Query {i+1}: {query} ===\n")
        print("Running multi-modal RAG...")
        multimodal_result = query_multimodal_rag(query, multimodal_vector_store)
        print("\nRunning text-only RAG...")
        text_only_result = query_multimodal_rag(query, text_only_vector_store)
        result = {
            "query": query,
            "multimodal_result": multimodal_result,
            "text_only_result": text_only_result,
        }
        if reference_answers and i < len(reference_answers):
            result["reference_answer"] = reference_answers[i]
        results.append(result)
    overall_analysis = generate_overall_analysis(results)
    return {
        "results": results,
        "overall_analysis": overall_analysis
    }

def generate_overall_analysis(results):
    system_prompt = """You are an expert evaluator of RAG systems. Based on the results below, compare the performance of multi-modal RAG (text + images) against text-only RAG. Focus on the types of queries where the multi-modal approach helps, what the image captions contribute, and the limitations of the multi-modal approach."""
    user_prompt = """Based on the following results, provide an overall analysis comparing Multi-Modal RAG and Text-Only RAG:

    """
    for i, result in enumerate(results):
        user_prompt += f"Query {i+1}: {result['query']}\n"
        user_prompt += f"Multi-modal Response: {result['multimodal_result']['response']}\n"
        user_prompt += f"Text-only Response: {result['text_only_result']['response']}\n"
        if 'reference_answer' in result:
            user_prompt += f"Reference Answer: {result['reference_answer']}\n"
        user_prompt += "\n"
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0.1
    )
    return response.choices[0].message.content

Evaluating the Multi-Modal RAG System

We test on the "Attention Is All You Need" paper, asking about the Transformer's reported BLEU scores.

pdf_path = "data/attention_is_all_you_need.pdf"
test_queries = [
    "What is the BLEU score of the Transformer (base model)?",
]
reference_answers = [
    "The Transformer (base model) achieves a BLEU score of 27.3 on the WMT 2014 English-to-German translation task and 38.1 on the WMT 2014 English-to-French translation task.",
]

evaluation_results = evaluate_multimodal_vs_textonly(
    pdf_path=pdf_path,
    test_queries=test_queries,
    reference_answers=reference_answers
)

print("\n=== OVERALL ANALYSIS ===\n")
print(evaluation_results["overall_analysis"])
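
Finally, the json module imported earlier can persist the evaluation output for later inspection; a small sketch (the filename is arbitrary):

# Everything in evaluation_results is plain strings, ints, and floats,
# so the default JSON encoder handles it directly.
with open("multimodal_vs_textonly_results.json", "w") as f:
    json.dump(evaluation_results, f, indent=2)
print("Saved results to multimodal_vs_textonly_results.json")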