
Posted by Jesse JCharis
Feb. 23, 2025, 5:33 a.m.
Building a Retrieval-Augmented Generation (RAG) System with Scikit-Learn Vectorizers
While modern LLMs excel at text generation, they often lack access to specific or up-to-date information. Retrieval-Augmented Generation (RAG) addresses this by combining document retrieval with text generation. In this tutorial, we'll implement a RAG system using Scikit-Learn's vectorizers for text processing and similarity search, instead of a dedicated vector search library such as FAISS.
Why Use Scikit-Learn for RAG?
Scikit-Learn offers lightweight TF-IDF/BOW vectorizers and similarity utilities that work well for:
- Small-to-medium datasets
- Prototyping RAG systems
- Projects requiring minimal dependencies
- Keyword-driven retrieval tasks
While less sophisticated than transformer-based embeddings, these traditional NLP tools remain effective for many use cases.
Implementation Walkthrough
1. Install Dependencies
!pip install scikit-learn transformers torch  # torch provides the backend for the transformers pipeline
2. Prepare Sample Data
documents = [
    "The Eiffel Tower was completed in 1889 for the World's Fair",
    "Mount Everest stands at 8,848 meters above sea level",
    "Python was created by Guido van Rossum and first released in 1991",
    "The Great Wall of China stretches over 21,000 kilometers",
    "Albert Einstein developed the theory of relativity"
]
3. Create TF-IDF Vectorizer and Document Store
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Fit on the document collection and build the sparse TF-IDF matrix
vectorizer = TfidfVectorizer(stop_words='english')
doc_vectors = vectorizer.fit_transform(documents)
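If you want to confirm what the vectorizer learned, a quick optional check looks like this (the exact vocabulary size depends on your documents):

# One row per document, one column per vocabulary term
print(doc_vectors.shape)
# Inspect a few of the learned terms
print(vectorizer.get_feature_names_out()[:10])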
4. Build Retriever Component
def retrieve(query: str, k=2) -> list:
    # Convert query to a TF-IDF vector
    query_vec = vectorizer.transform([query])
    # Compute cosine similarities against every document vector
    similarities = cosine_similarity(query_vec, doc_vectors)
    # Get the indices of the top-k most similar documents
    top_indices = similarities.argsort()[0][-k:][::-1]
    return [documents[i] for i in top_indices]
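You can call the retriever on its own to see exactly what will be handed to the generator (results depend on your corpus):

for doc in retrieve("When was Python released?", k=2):
    print("-", doc)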
5. Set Up Generator Model
from transformers import pipeline
generator = pipeline(
    "text2text-generation",
    model="google/flan-t5-base"  # Good at following instructions
)
6. Assemble RAG Pipeline
def rag_pipeline(query: str) -> str:
    # Retrieve relevant context documents
    context_docs = retrieve(query)
    # Format the input for the generator
    context = "\n".join(context_docs)
    prompt = f"Answer this question: {query}\nUsing this context:\n{context}"
    # Generate the response
    result = generator(prompt, max_length=200)
    return result[0]['generated_text']
Practical Examples
Example 1: Historical Fact Retrieval
question = "When was Python first released?"
print(rag_pipeline(question))
Output:
"Python was first released in 1991"
Retrieved Context:
- "Python was created by Guido van Rossum..."
- "The Eiffel Tower was completed in 1889..."
Example 2: Scientific Fact Query
question = "Who developed the theory of relativity?"
print(rag_pipeline(question))
Output:
"Albert Einstein developed the theory of relativity"
Retrieved Context:
- "Albert Einstein developed..."
- "Mount Everest stands at..."
Performance Optimization Tips
- Preprocessing Matters:
vectorizer = TfidfVectorizer(
    stop_words='english',
    ngram_range=(1, 2),   # Capture two-word phrases as well as single terms
    max_features=5000     # Cap the vocabulary to control memory usage
)
- BM25-Style Scoring: Combine TF-IDF weighting with BM25 term-frequency saturation (a usage sketch follows this list):

import numpy as np
from scipy.sparse import csr_matrix
from sklearn.feature_extraction.text import TfidfVectorizer

class BM25Vectorizer(TfidfVectorizer):
    def __init__(self, k=1.2, b=0.75, **kwargs):
        super().__init__(**kwargs)
        self.k = k
        self.b = b

    def transform(self, raw_documents):
        tfidf = super().transform(raw_documents)
        # Document lengths (in whitespace tokens) and their average
        doc_lens = np.array([len(doc.split()) for doc in raw_documents])
        avgdl = doc_lens.mean()
        # Repeat each document's length once per non-zero entry in its row
        lens_per_entry = np.repeat(doc_lens, np.diff(tfidf.indptr))
        # Apply BM25 saturation to the stored TF-IDF weights
        values = tfidf.data * (self.k + 1) / (
            tfidf.data + self.k * (1 - self.b + self.b * lens_per_entry / avgdl)
        )
        return csr_matrix((values, tfidf.indices, tfidf.indptr), shape=tfidf.shape)
- Caching Vectors: Store precomputed document vectors so they do not have to be rebuilt on every run:

import joblib

# Save vectors
joblib.dump(doc_vectors, 'doc_vectors.pkl')
# Load later
doc_vectors = joblib.load('doc_vectors.pkl')
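Putting the last two tips together, here is a minimal sketch of how they might be wired up. It assumes the BM25Vectorizer defined above and the documents list from earlier; the file names are just placeholders. Note that the fitted vectorizer has to be saved alongside the matrix, otherwise new queries cannot be transformed after reloading.

import joblib

# Fit the BM25-weighted vectorizer on the sample corpus
bm25 = BM25Vectorizer(stop_words='english')
bm25.fit(documents)                      # learn vocabulary and IDF weights
doc_vectors = bm25.transform(documents)  # BM25-weighted document matrix

# Persist both the fitted vectorizer and the document vectors
joblib.dump(bm25, 'bm25_vectorizer.pkl')
joblib.dump(doc_vectors, 'doc_vectors.pkl')

# Later: reload and transform a query without refitting
bm25 = joblib.load('bm25_vectorizer.pkl')
doc_vectors = joblib.load('doc_vectors.pkl')
query_vec = bm25.transform(["When was Python released?"])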
When to Use This Approach
| Scenario | Why this approach fits |
|---|---|
| Small document collections | Up to roughly 100k documents |
| Keyword-focused queries | Fact retrieval rather than semantic search |
| Prototyping | Quick to implement and debug |
| Resource-constrained environments | No GPU required |
Limitations and Solutions
- Semantic Understanding:
  Problem: TF-IDF struggles with synonymy (e.g. "car" vs "automobile")
  Fix: Add synonym expansion during preprocessing
- Scalability:
  Problem: Full matrix comparisons become slow as the corpus grows
  Fix: Use approximate nearest neighbors (ANN) from annoy or nmslib
- Context Quality:
  Problem: Irrelevant retrieved documents pollute the prompt
  Fix: Add a reranking step using a cross-encoder (see the sketch below)
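As one possible version of that last fix, the sketch below reranks a larger TF-IDF candidate set with a pretrained cross-encoder from the sentence-transformers package (an extra dependency not used elsewhere in this tutorial; the helper name retrieve_and_rerank and the model choice are just examples):

from sentence_transformers import CrossEncoder

# Pretrained cross-encoder that scores (query, document) pairs
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

def retrieve_and_rerank(query: str, k: int = 2, candidates: int = 10) -> list:
    docs = retrieve(query, k=candidates)  # cheap TF-IDF pre-selection
    scores = reranker.predict([(query, doc) for doc in docs])
    ranked = sorted(zip(scores, docs), reverse=True)
    return [doc for _, doc in ranked[:k]]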
Full Enhanced Implementation
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors
from transformers import pipeline

class EnhancedRAG:
    def __init__(self):
        self.vectorizer = TfidfVectorizer(
            stop_words='english',
            ngram_range=(1, 2),
            max_features=5000,
            analyzer='word'
        )
        self.generator = pipeline(
            "text2text-generation",
            model="google/flan-t5-base"
        )

    def index(self, documents):
        self.documents = documents
        self.doc_vectors = self.vectorizer.fit_transform(documents)
        # Build a nearest-neighbour index for faster lookups
        # (Euclidean distance on L2-normalised TF-IDF rows ranks the same as cosine similarity)
        self.nn_index = NearestNeighbors(n_neighbors=2)
        self.nn_index.fit(self.doc_vectors)

    def retrieve(self, query):
        query_vec = self.vectorizer.transform([query])
        _, indices = self.nn_index.kneighbors(query_vec)
        return [self.documents[i] for i in indices[0]]

    def generate(self, query):
        context = "\n".join(self.retrieve(query))
        prompt = f"Question: {query}\nContext:\n{context}\nAnswer:"
        return self.generator(prompt)[0]['generated_text']
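A quick usage example, reusing the sample documents list from earlier (the exact answer depends on the generator model):

rag = EnhancedRAG()
rag.index(documents)
print(rag.generate("How tall is Mount Everest?"))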
Conclusion
This Scikit-Learn-based RAG implementation demonstrates how traditional NLP tools can effectively augment LLMs for fact-based question answering. While less sophisticated than transformer-based retrieval systems using FAISS or Pinecone, it offers:
- Minimal setup requirements
- Fast implementation
- Easy debugging
For production systems handling millions of documents or requiring semantic understanding, consider combining this approach with:
- Sentence-transformers embeddings (see the sketch below)
- Dedicated vector databases
- Advanced reranking models
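As a pointer in that direction, swapping the TF-IDF retriever for dense sentence-transformers embeddings could look roughly like this (a sketch only; the model name is just one common choice):

from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Encode the documents once with a small embedding model
embedder = SentenceTransformer('all-MiniLM-L6-v2')
doc_embeddings = embedder.encode(documents)

def semantic_retrieve(query: str, k: int = 2) -> list:
    query_embedding = embedder.encode([query])
    similarities = cosine_similarity(query_embedding, doc_embeddings)
    top_indices = similarities.argsort()[0][-k:][::-1]
    return [documents[i] for i in top_indices]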
Happy Coding
Jesus Saves