This is the capstone project. Everything you’ve learned across the previous fourteen lessons comes together here: prompt engineering, RAG pipelines, chunking strategies, streaming, cost optimization, security, observability, and evaluation.
You’ll build a production-ready knowledge bot that ingests documents, retrieves relevant context, generates streaming answers with source citations, tracks costs, validates inputs, and comes with an evaluation harness. Every file is complete and runnable.
What We’re Building
The knowledge bot has these capabilities:
- Document ingestion — load Markdown, text, and PDF files, chunk them, embed them, and store them in a vector database
- Conversational Q&A — answer questions using retrieved context, maintain conversation history per session
- Streaming responses — Server-Sent Events (SSE) endpoint for real-time token streaming
- Source citation — every answer includes the document chunks that contributed to it
- Input validation — rate limiting, length checks, prompt injection detection
- Cost tracking — per-request token and cost logging
- Evaluation harness — automated quality testing with a test dataset
The stack:
| Component | Technology |
|---|---|
| API framework | FastAPI |
| Vector store | ChromaDB (persistent) |
| Embeddings | OpenAI text-embedding-3-small |
| Generation | OpenAI GPT-4o |
| Streaming | Server-Sent Events (SSE) |
| Configuration | Environment variables + Pydantic |
Project Structure
knowledge-bot/
├── config.py # Configuration and environment variables
├── ingest.py # Document loading, chunking, embedding
├── retriever.py # Vector search and reranking
├── generator.py # LLM generation with streaming
├── api.py # FastAPI application and endpoints
├── observability.py # Cost tracking and logging
├── security.py # Input validation and injection detection
├── evaluate.py # Evaluation harness
├── requirements.txt # Dependencies
├── Dockerfile # Container build
├── .env.example # Environment variable template
└── tests/
└── test_dataset.json # Evaluation test casesConfiguration
Start with a centralized configuration module. Every tunable parameter lives here — no magic numbers scattered across files.
# config.py
import os
from pydantic_settings import BaseSettings
from typing import Optional
class Settings(BaseSettings):
"""All configuration in one place. Override with environment variables."""
# OpenAI
openai_api_key: str
embedding_model: str = "text-embedding-3-small"
generation_model: str = "gpt-4o"
generation_temperature: float = 0.1
generation_max_tokens: int = 1024
# ChromaDB
chroma_persist_dir: str = "./chroma_data"
collection_name: str = "knowledge_base"
# Chunking
chunk_size: int = 512
chunk_overlap: int = 64
# Retrieval
top_k: int = 5
similarity_threshold: float = 0.3
# API
max_query_length: int = 1000
max_history_messages: int = 10
rate_limit_per_minute: int = 30
# Cost tracking
embedding_cost_per_token: float = 0.02 / 1_000_000 # text-embedding-3-small
generation_input_cost_per_token: float = 2.50 / 1_000_000 # gpt-4o input
generation_output_cost_per_token: float = 10.00 / 1_000_000 # gpt-4o output
class Config:
env_file = ".env"
settings = Settings()And the environment template:
# .env.example
OPENAI_API_KEY=sk-your-key-here
EMBEDDING_MODEL=text-embedding-3-small
GENERATION_MODEL=gpt-4o
CHROMA_PERSIST_DIR=./chroma_data
CHUNK_SIZE=512
CHUNK_OVERLAP=64
TOP_K=5Document Ingestion Pipeline
The ingestion pipeline follows the pattern from Lessons 5 and 7: load files, chunk them, embed the chunks, and store them in the vector database.
# ingest.py
import os
import hashlib
from pathlib import Path
from typing import Generator
import chromadb
import openai
from chromadb.config import Settings as ChromaSettings
from config import settings
# --- Document Loading ---
def load_text_file(path: str) -> dict:
"""Load a plain text or markdown file."""
with open(path, "r", encoding="utf-8") as f:
content = f.read()
return {
"content": content,
"source": os.path.basename(path),
"path": path,
"type": path.split(".")[-1]
}
def load_pdf_file(path: str) -> dict:
"""Load a PDF file. Requires pypdf."""
try:
from pypdf import PdfReader
except ImportError:
raise ImportError("Install pypdf: pip install pypdf")
reader = PdfReader(path)
pages = []
for i, page in enumerate(reader.pages):
text = page.extract_text()
if text.strip():
pages.append(text)
return {
"content": "\n\n".join(pages),
"source": os.path.basename(path),
"path": path,
"type": "pdf",
"page_count": len(reader.pages)
}
def load_document(path: str) -> dict:
"""Load a document based on file extension."""
ext = path.lower().split(".")[-1]
if ext in ("txt", "md", "markdown"):
return load_text_file(path)
elif ext == "pdf":
return load_pdf_file(path)
else:
raise ValueError(f"Unsupported file type: {ext}. Supported: txt, md, pdf")
def load_directory(directory: str) -> list[dict]:
"""Load all supported documents from a directory."""
supported = {".txt", ".md", ".markdown", ".pdf"}
documents = []
for path in sorted(Path(directory).rglob("*")):
if path.suffix.lower() in supported and path.is_file():
try:
doc = load_document(str(path))
documents.append(doc)
print(f" Loaded: {path.name} ({len(doc['content'])} chars)")
except Exception as e:
print(f" Error loading {path.name}: {e}")
return documents
# --- Chunking ---
def chunk_document(document: dict, chunk_size: int = None, overlap: int = None) -> list[dict]:
"""
Split a document into overlapping chunks.
Uses recursive character splitting (Lesson 7 approach).
"""
chunk_size = chunk_size or settings.chunk_size
overlap = overlap or settings.chunk_overlap
text = document["content"]
source = document["source"]
# Split on paragraph boundaries first, then sentences, then characters
separators = ["\n\n", "\n", ". ", " ", ""]
chunks = _recursive_split(text, separators, chunk_size, overlap)
result = []
for i, chunk_text in enumerate(chunks):
chunk_id = hashlib.md5(f"{source}:{i}:{chunk_text[:50]}".encode()).hexdigest()[:12]
result.append({
"id": chunk_id,
"text": chunk_text,
"metadata": {
"source": source,
"chunk_index": i,
"total_chunks": len(chunks), # Will be updated after all chunks are created
"char_count": len(chunk_text),
}
})
# Update total_chunks now that we know the count
for chunk in result:
chunk["metadata"]["total_chunks"] = len(result)
return result
def _recursive_split(text: str, separators: list[str], chunk_size: int, overlap: int) -> list[str]:
"""Recursively split text using a hierarchy of separators."""
if len(text) <= chunk_size:
return [text] if text.strip() else []
# Find the best separator
separator = separators[-1]
for sep in separators:
if sep in text:
separator = sep
break
parts = text.split(separator)
chunks = []
current_chunk = ""
for part in parts:
piece = part if not separator else part + separator
if len(current_chunk) + len(piece) <= chunk_size:
current_chunk += piece
else:
if current_chunk.strip():
chunks.append(current_chunk.strip())
# Start new chunk with overlap from previous
if overlap > 0 and current_chunk:
overlap_text = current_chunk[-overlap:]
current_chunk = overlap_text + piece
else:
current_chunk = piece
if current_chunk.strip():
chunks.append(current_chunk.strip())
return chunks
# --- Embedding ---
def embed_texts(texts: list[str], model: str = None) -> list[list[float]]:
"""Embed a batch of texts using OpenAI's embedding API."""
model = model or settings.embedding_model
client = openai.OpenAI(api_key=settings.openai_api_key)
# OpenAI allows up to 2048 texts per batch
batch_size = 512
all_embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i:i + batch_size]
response = client.embeddings.create(
model=model,
input=batch
)
batch_embeddings = [item.embedding for item in response.data]
all_embeddings.extend(batch_embeddings)
print(f" Embedded batch {i // batch_size + 1}: {len(batch)} texts")
return all_embeddings
# --- Vector Store ---
def get_chroma_client() -> chromadb.ClientAPI:
"""Get or create a persistent ChromaDB client."""
return chromadb.PersistentClient(
path=settings.chroma_persist_dir,
settings=ChromaSettings(anonymized_telemetry=False)
)
def get_collection(client: chromadb.ClientAPI = None) -> chromadb.Collection:
"""Get or create the knowledge base collection."""
client = client or get_chroma_client()
return client.get_or_create_collection(
name=settings.collection_name,
metadata={"hnsw:space": "cosine"}
)
def store_chunks(chunks: list[dict], collection: chromadb.Collection = None):
"""Embed and store chunks in ChromaDB."""
collection = collection or get_collection()
texts = [c["text"] for c in chunks]
ids = [c["id"] for c in chunks]
metadatas = [c["metadata"] for c in chunks]
print(f" Embedding {len(texts)} chunks...")
embeddings = embed_texts(texts)
# ChromaDB batch limit is 41666
batch_size = 5000
for i in range(0, len(ids), batch_size):
end = i + batch_size
collection.upsert(
ids=ids[i:end],
embeddings=embeddings[i:end],
documents=texts[i:end],
metadatas=metadatas[i:end]
)
print(f" Stored batch: {min(end, len(ids))}/{len(ids)} chunks")
print(f" Total chunks in collection: {collection.count()}")
# --- Main Ingestion ---
def ingest(source_dir: str):
"""Full ingestion pipeline: load → chunk → embed → store."""
print(f"Loading documents from {source_dir}...")
documents = load_directory(source_dir)
print(f"Loaded {len(documents)} documents.")
all_chunks = []
for doc in documents:
chunks = chunk_document(doc)
all_chunks.extend(chunks)
print(f" {doc['source']}: {len(chunks)} chunks")
print(f"\nTotal chunks: {len(all_chunks)}")
store_chunks(all_chunks)
print("Ingestion complete.")
if __name__ == "__main__":
import sys
if len(sys.argv) < 2:
print("Usage: python ingest.py <source_directory>")
sys.exit(1)
ingest(sys.argv[1])Run the ingestion:
python ingest.py ./documents/Retriever
The retriever handles vector search and result formatting. It takes a query, embeds it, searches the vector store, and returns ranked chunks with their metadata.
# retriever.py
from typing import Optional
import chromadb
import openai
from config import settings
from ingest import get_collection, embed_texts
class Retriever:
"""Vector search with filtering and formatting."""
def __init__(self, collection: Optional[chromadb.Collection] = None):
self.collection = collection or get_collection()
self.client = openai.OpenAI(api_key=settings.openai_api_key)
def search(self, query: str, top_k: int = None, source_filter: str = None) -> list[dict]:
"""
Search for relevant chunks.
Returns a list of dicts with text, source, score, and metadata.
"""
top_k = top_k or settings.top_k
# Embed the query
query_embedding = embed_texts([query])[0]
# Build where filter
where_filter = None
if source_filter:
where_filter = {"source": source_filter}
# Search
results = self.collection.query(
query_embeddings=[query_embedding],
n_results=top_k,
where=where_filter,
include=["documents", "metadatas", "distances"]
)
# Format results
chunks = []
for i in range(len(results["ids"][0])):
distance = results["distances"][0][i]
# ChromaDB cosine distance: 0 = identical, 2 = opposite
# Convert to similarity score: 1 - (distance / 2)
similarity = 1 - (distance / 2)
if similarity < settings.similarity_threshold:
continue
chunks.append({
"id": results["ids"][0][i],
"text": results["documents"][0][i],
"source": results["metadatas"][0][i].get("source", "unknown"),
"chunk_index": results["metadatas"][0][i].get("chunk_index", 0),
"similarity": round(similarity, 4),
"metadata": results["metadatas"][0][i]
})
# Sort by similarity (highest first)
chunks.sort(key=lambda x: x["similarity"], reverse=True)
return chunks
def format_context(self, chunks: list[dict]) -> str:
"""Format retrieved chunks into a context string for the LLM."""
if not chunks:
return "No relevant documents found."
context_parts = []
for i, chunk in enumerate(chunks, 1):
context_parts.append(
f"[Source {i}: {chunk['source']}] (relevance: {chunk['similarity']:.0%})\n"
f"{chunk['text']}"
)
return "\n\n---\n\n".join(context_parts)
def get_source_citations(self, chunks: list[dict]) -> list[dict]:
"""Extract citation information from chunks."""
seen_sources = set()
citations = []
for chunk in chunks:
source = chunk["source"]
if source not in seen_sources:
seen_sources.add(source)
citations.append({
"source": source,
"relevance": chunk["similarity"],
"chunk_count": sum(1 for c in chunks if c["source"] == source)
})
return citationsGenerator
The generator handles LLM calls — both streaming and non-streaming. It builds the prompt from the query, context, and conversation history.
# generator.py
import json
from typing import AsyncGenerator, Optional
import openai
from config import settings
from observability import track_generation
SYSTEM_PROMPT = """You are a knowledgeable assistant that answers questions based on the provided context.
Rules:
1. Answer ONLY based on the provided context. If the context doesn't contain enough information, say so clearly.
2. Cite your sources using [Source N] notation when referencing specific information.
3. Be concise but thorough. Prefer short paragraphs over walls of text.
4. If the user asks a follow-up question, use the conversation history for context but still ground your answer in the retrieved documents.
5. Never make up information. If you're unsure, say "Based on the available documents, I cannot confirm this."
"""
def build_messages(
query: str,
context: str,
history: Optional[list[dict]] = None
) -> list[dict]:
"""Build the message list for the LLM call."""
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
# Add conversation history (limited to prevent token bloat)
if history:
recent_history = history[-(settings.max_history_messages * 2):]
messages.extend(recent_history)
# Add the current query with context
user_message = f"""Context from knowledge base:
{context}
Question: {query}"""
messages.append({"role": "user", "content": user_message})
return messages
def generate(
query: str,
context: str,
history: Optional[list[dict]] = None
) -> dict:
"""Generate a non-streaming response."""
client = openai.OpenAI(api_key=settings.openai_api_key)
messages = build_messages(query, context, history)
response = client.chat.completions.create(
model=settings.generation_model,
messages=messages,
temperature=settings.generation_temperature,
max_tokens=settings.generation_max_tokens,
)
usage = response.usage
cost = track_generation(usage.prompt_tokens, usage.completion_tokens)
return {
"answer": response.choices[0].message.content,
"model": settings.generation_model,
"usage": {
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
"total_tokens": usage.total_tokens,
"cost_usd": cost
},
"finish_reason": response.choices[0].finish_reason
}
async def generate_stream(
query: str,
context: str,
history: Optional[list[dict]] = None
) -> AsyncGenerator[str, None]:
"""Generate a streaming response using SSE format."""
client = openai.OpenAI(api_key=settings.openai_api_key)
messages = build_messages(query, context, history)
stream = client.chat.completions.create(
model=settings.generation_model,
messages=messages,
temperature=settings.generation_temperature,
max_tokens=settings.generation_max_tokens,
stream=True,
stream_options={"include_usage": True}
)
full_response = ""
prompt_tokens = 0
completion_tokens = 0
for chunk in stream:
# Handle usage info (comes in the last chunk)
if chunk.usage:
prompt_tokens = chunk.usage.prompt_tokens
completion_tokens = chunk.usage.completion_tokens
if chunk.choices and chunk.choices[0].delta.content:
token = chunk.choices[0].delta.content
full_response += token
# SSE format: each event is "data: ...\n\n"
yield f"data: {json.dumps({'token': token})}\n\n"
# Send final event with metadata
cost = track_generation(prompt_tokens, completion_tokens)
yield f"data: {json.dumps({'done': True, 'usage': {'prompt_tokens': prompt_tokens, 'completion_tokens': completion_tokens, 'cost_usd': cost}})}\n\n"Observability
Cost tracking and structured logging — keeping it simple and dependency-free, using the patterns from Lesson 14.
# observability.py
import json
import time
import logging
from datetime import datetime
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from typing import Optional
logger = logging.getLogger("knowledge_bot")
logger.setLevel(logging.INFO)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(handler)
from config import settings
# --- Cost tracking ---
_daily_costs = defaultdict(float)
_request_count = defaultdict(int)
def track_embedding(token_count: int) -> float:
"""Track embedding cost."""
cost = token_count * settings.embedding_cost_per_token
today = datetime.utcnow().strftime("%Y-%m-%d")
_daily_costs[today] += cost
return round(cost, 8)
def track_generation(prompt_tokens: int, completion_tokens: int) -> float:
"""Track generation cost."""
cost = (
prompt_tokens * settings.generation_input_cost_per_token +
completion_tokens * settings.generation_output_cost_per_token
)
today = datetime.utcnow().strftime("%Y-%m-%d")
_daily_costs[today] += cost
_request_count[today] += 1
return round(cost, 6)
def get_daily_stats() -> dict:
"""Get today's cost and usage stats."""
today = datetime.utcnow().strftime("%Y-%m-%d")
return {
"date": today,
"total_cost_usd": round(_daily_costs.get(today, 0), 4),
"request_count": _request_count.get(today, 0),
"avg_cost_per_request": round(
_daily_costs.get(today, 0) / max(_request_count.get(today, 0), 1), 6
)
}
# --- Request logging ---
@dataclass
class RequestLog:
request_id: str
query: str # Truncated for privacy
chunks_retrieved: int
top_similarity: float
prompt_tokens: int
completion_tokens: int
cost_usd: float
latency_ms: float
sources: list[str]
error: Optional[str] = None
timestamp: str = field(default_factory=lambda: datetime.utcnow().isoformat())
def log_request(log: RequestLog):
"""Emit structured log for the request."""
data = asdict(log)
data["query"] = data["query"][:100] # Truncate for privacy
logger.info(json.dumps(data, default=str))Security
Input validation and basic prompt injection detection — applying the patterns from Lessons 11 and 12.
# security.py
import re
import time
from collections import defaultdict
from typing import Optional
from config import settings
# --- Rate limiting ---
_request_timestamps: dict[str, list[float]] = defaultdict(list)
def check_rate_limit(client_id: str) -> bool:
"""Simple sliding window rate limiter. Returns True if request is allowed."""
now = time.time()
window = 60.0 # 1 minute window
# Clean old timestamps
_request_timestamps[client_id] = [
ts for ts in _request_timestamps[client_id]
if now - ts < window
]
if len(_request_timestamps[client_id]) >= settings.rate_limit_per_minute:
return False
_request_timestamps[client_id].append(now)
return True
# --- Input validation ---
def validate_query(query: str) -> tuple[bool, Optional[str]]:
"""
Validate a user query.
Returns (is_valid, error_message).
"""
if not query or not query.strip():
return False, "Query cannot be empty."
if len(query) > settings.max_query_length:
return False, f"Query too long. Maximum {settings.max_query_length} characters."
if len(query.strip()) < 3:
return False, "Query too short. Please ask a complete question."
return True, None
# --- Prompt injection detection ---
INJECTION_PATTERNS = [
r"ignore\s+(all\s+)?previous\s+instructions",
r"ignore\s+(all\s+)?above\s+instructions",
r"disregard\s+(all\s+)?previous",
r"forget\s+(everything|all|your)\s+(instructions|rules|prompts)",
r"you\s+are\s+now\s+(a|an)\s+",
r"new\s+instructions?:\s*",
r"system\s*prompt\s*:",
r"<\s*system\s*>",
r"\[INST\]",
r"###\s*(system|instruction|human|assistant)\s*:",
r"pretend\s+(you\s+are|to\s+be)",
r"act\s+as\s+(if\s+you\s+are|a|an)\s+",
r"override\s+(your\s+)?(instructions|rules|system)",
r"reveal\s+(your\s+)?(system\s+)?(prompt|instructions)",
r"what\s+is\s+your\s+system\s+prompt",
r"output\s+your\s+(initial|system|first)\s+(prompt|instructions|message)",
]
def detect_injection(query: str) -> tuple[bool, Optional[str]]:
"""
Check for common prompt injection patterns.
Returns (is_suspicious, matched_pattern).
"""
query_lower = query.lower()
for pattern in INJECTION_PATTERNS:
if re.search(pattern, query_lower):
return True, pattern
return False, None
def sanitize_input(query: str) -> str:
"""Basic input sanitization."""
# Remove null bytes
query = query.replace("\x00", "")
# Normalize whitespace
query = " ".join(query.split())
return query.strip()FastAPI Application
The API layer ties everything together. It exposes endpoints for querying (both streaming and non-streaming), ingestion status, and health checks.
# api.py
import uuid
import time
import json
from typing import Optional
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from config import settings
from retriever import Retriever
from generator import generate, generate_stream, build_messages
from observability import log_request, RequestLog, get_daily_stats
from security import check_rate_limit, validate_query, detect_injection, sanitize_input
from ingest import get_collection
app = FastAPI(
title="Knowledge Bot API",
version="1.0.0",
description="RAG-powered knowledge assistant"
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Restrict in production
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
# Initialize retriever
retriever = Retriever()
# In-memory session store (use Redis in production)
sessions: dict[str, list[dict]] = {}
# --- Request/Response models ---
class QueryRequest(BaseModel):
query: str = Field(..., min_length=1, max_length=1000, description="The question to ask")
session_id: Optional[str] = Field(None, description="Session ID for conversation history")
stream: bool = Field(False, description="Whether to stream the response")
top_k: Optional[int] = Field(None, ge=1, le=20, description="Number of documents to retrieve")
class QueryResponse(BaseModel):
answer: str
sources: list[dict]
session_id: str
usage: dict
class HealthResponse(BaseModel):
status: str
collection_count: int
daily_stats: dict
# --- Endpoints ---
@app.get("/health", response_model=HealthResponse)
async def health_check():
"""Health check endpoint with collection stats."""
try:
collection = get_collection()
count = collection.count()
except Exception:
count = -1
return HealthResponse(
status="ok" if count >= 0 else "degraded",
collection_count=count,
daily_stats=get_daily_stats()
)
@app.post("/query")
async def query_knowledge_base(request: Request, body: QueryRequest):
"""
Query the knowledge bot. Supports both streaming and non-streaming responses.
"""
request_id = str(uuid.uuid4())[:8]
start_time = time.time()
# Rate limiting
client_ip = request.client.host
if not check_rate_limit(client_ip):
raise HTTPException(status_code=429, detail="Rate limit exceeded. Try again in a minute.")
# Sanitize and validate
query = sanitize_input(body.query)
is_valid, error = validate_query(query)
if not is_valid:
raise HTTPException(status_code=400, detail=error)
# Injection detection
is_suspicious, pattern = detect_injection(query)
if is_suspicious:
raise HTTPException(
status_code=400,
detail="Your query could not be processed. Please rephrase your question."
)
# Session management
session_id = body.session_id or str(uuid.uuid4())
history = sessions.get(session_id, [])
# Retrieve relevant chunks
top_k = body.top_k or settings.top_k
chunks = retriever.search(query, top_k=top_k)
context = retriever.format_context(chunks)
citations = retriever.get_source_citations(chunks)
if body.stream:
# Streaming response
async def event_generator():
# First, send the sources
yield f"data: {json.dumps({'sources': citations})}\n\n"
# Then stream the answer tokens
async for event in generate_stream(query, context, history):
yield event
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Request-ID": request_id,
"X-Session-ID": session_id,
}
)
else:
# Non-streaming response
result = generate(query, context, history)
# Update session history
history.append({"role": "user", "content": query})
history.append({"role": "assistant", "content": result["answer"]})
sessions[session_id] = history
latency_ms = (time.time() - start_time) * 1000
# Log the request
log_request(RequestLog(
request_id=request_id,
query=query,
chunks_retrieved=len(chunks),
top_similarity=chunks[0]["similarity"] if chunks else 0.0,
prompt_tokens=result["usage"]["prompt_tokens"],
completion_tokens=result["usage"]["completion_tokens"],
cost_usd=result["usage"]["cost_usd"],
latency_ms=round(latency_ms, 1),
sources=[c["source"] for c in chunks],
))
return QueryResponse(
answer=result["answer"],
sources=citations,
session_id=session_id,
usage=result["usage"]
)
@app.get("/stats")
async def get_stats():
"""Get usage statistics."""
return get_daily_stats()
@app.delete("/sessions/{session_id}")
async def clear_session(session_id: str):
"""Clear conversation history for a session."""
if session_id in sessions:
del sessions[session_id]
return {"status": "cleared"}
raise HTTPException(status_code=404, detail="Session not found")Run the API:
uvicorn api:app --host 0.0.0.0 --port 8000 --reloadTesting the API
Non-streaming query:
curl -X POST http://localhost:8000/query \
-H "Content-Type: application/json" \
-d '{"query": "What is our return policy?", "stream": false}'Streaming query:
curl -X POST http://localhost:8000/query \
-H "Content-Type: application/json" \
-d '{"query": "How do I reset my password?", "stream": true}'With conversation history (pass the session_id from a previous response):
curl -X POST http://localhost:8000/query \
-H "Content-Type: application/json" \
-d '{"query": "Can you elaborate on that?", "session_id": "abc-123", "stream": false}'Evaluation Harness
The evaluation suite runs automated quality checks against a test dataset. It tests retrieval quality (did we find the right chunks?) and generation quality (is the answer correct?).
# evaluate.py
import json
import time
from typing import Optional
from dataclasses import dataclass, asdict
import openai
from config import settings
from retriever import Retriever
from generator import generate
@dataclass
class TestCase:
query: str
expected_keywords: list[str] # Keywords that should appear in the answer
expected_source: Optional[str] = None # Source document that should be retrieved
category: str = "general"
@dataclass
class EvalResult:
query: str
answer: str
sources_found: list[str]
keyword_score: float # 0-1: fraction of expected keywords found
source_hit: bool # Was the expected source document retrieved?
latency_ms: float
tokens_used: int
cost_usd: float
passed: bool
def load_test_dataset(path: str) -> list[TestCase]:
"""Load test cases from a JSON file."""
with open(path, "r") as f:
data = json.load(f)
return [TestCase(**tc) for tc in data]
def keyword_score(answer: str, expected_keywords: list[str]) -> float:
"""Calculate what fraction of expected keywords appear in the answer."""
if not expected_keywords:
return 1.0
answer_lower = answer.lower()
found = sum(1 for kw in expected_keywords if kw.lower() in answer_lower)
return found / len(expected_keywords)
def llm_judge_score(query: str, answer: str, expected_keywords: list[str]) -> float:
"""
Use an LLM to judge answer quality.
More expensive but catches nuance that keyword matching misses.
"""
client = openai.OpenAI(api_key=settings.openai_api_key)
judge_prompt = f"""You are evaluating the quality of an AI assistant's answer.
Question: {query}
Expected topics/keywords: {', '.join(expected_keywords)}
Answer given: {answer}
Rate the answer from 0.0 to 1.0:
- 1.0 = Complete, accurate, addresses the question fully
- 0.7 = Mostly correct but missing some details
- 0.4 = Partially correct with significant gaps
- 0.0 = Wrong, irrelevant, or refuses to answer
Respond with ONLY a number between 0.0 and 1.0."""
response = client.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": judge_prompt}],
temperature=0.0,
max_tokens=10
)
try:
score = float(response.choices[0].message.content.strip())
return max(0.0, min(1.0, score))
except ValueError:
return 0.5 # Default if parsing fails
def run_evaluation(
test_cases: list[TestCase],
use_llm_judge: bool = False,
pass_threshold: float = 0.6
) -> dict:
"""
Run the full evaluation suite.
Returns aggregate metrics and per-case results.
"""
retriever = Retriever()
results: list[EvalResult] = []
print(f"Running evaluation with {len(test_cases)} test cases...")
print(f"Pass threshold: {pass_threshold}")
print(f"LLM judge: {'enabled' if use_llm_judge else 'disabled'}")
print("-" * 60)
total_start = time.time()
for i, tc in enumerate(test_cases, 1):
start = time.time()
# Retrieve
chunks = retriever.search(tc.query)
context = retriever.format_context(chunks)
sources = [c["source"] for c in chunks]
# Generate
result = generate(tc.query, context)
answer = result["answer"]
latency = (time.time() - start) * 1000
# Score
kw_score = keyword_score(answer, tc.expected_keywords)
if use_llm_judge:
judge_score = llm_judge_score(tc.query, answer, tc.expected_keywords)
final_score = (kw_score + judge_score) / 2
else:
final_score = kw_score
source_hit = tc.expected_source in sources if tc.expected_source else True
passed = final_score >= pass_threshold and source_hit
eval_result = EvalResult(
query=tc.query,
answer=answer[:200], # Truncate for readability
sources_found=sources,
keyword_score=round(kw_score, 3),
source_hit=source_hit,
latency_ms=round(latency, 1),
tokens_used=result["usage"]["total_tokens"],
cost_usd=result["usage"]["cost_usd"],
passed=passed
)
results.append(eval_result)
status = "PASS" if passed else "FAIL"
print(f" [{status}] {i}/{len(test_cases)}: {tc.query[:60]}... (score: {final_score:.2f})")
total_time = time.time() - total_start
# Aggregate metrics
passed_count = sum(1 for r in results if r.passed)
total_count = len(results)
summary = {
"total_cases": total_count,
"passed": passed_count,
"failed": total_count - passed_count,
"pass_rate": round(passed_count / max(total_count, 1) * 100, 1),
"avg_keyword_score": round(sum(r.keyword_score for r in results) / max(total_count, 1), 3),
"avg_latency_ms": round(sum(r.latency_ms for r in results) / max(total_count, 1), 1),
"total_tokens": sum(r.tokens_used for r in results),
"total_cost_usd": round(sum(r.cost_usd for r in results), 4),
"total_time_seconds": round(total_time, 1),
"source_hit_rate": round(
sum(1 for r in results if r.source_hit) / max(total_count, 1) * 100, 1
),
}
print("-" * 60)
print(f"Results: {passed_count}/{total_count} passed ({summary['pass_rate']}%)")
print(f"Avg keyword score: {summary['avg_keyword_score']}")
print(f"Avg latency: {summary['avg_latency_ms']}ms")
print(f"Total cost: ${summary['total_cost_usd']}")
print(f"Total time: {summary['total_time_seconds']}s")
# Show failures
failures = [r for r in results if not r.passed]
if failures:
print(f"\nFailed cases ({len(failures)}):")
for r in failures:
print(f" - {r.query[:80]}... (keyword_score={r.keyword_score}, source_hit={r.source_hit})")
return {
"summary": summary,
"results": [asdict(r) for r in results],
"failures": [asdict(r) for r in failures]
}
if __name__ == "__main__":
import sys
dataset_path = sys.argv[1] if len(sys.argv) > 1 else "tests/test_dataset.json"
use_judge = "--llm-judge" in sys.argv
test_cases = load_test_dataset(dataset_path)
results = run_evaluation(test_cases, use_llm_judge=use_judge)
# Save results
output_path = "eval_results.json"
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
print(f"\nResults saved to {output_path}")Test Dataset Format
[
{
"query": "What is the return policy?",
"expected_keywords": ["30 days", "refund", "receipt"],
"expected_source": "return-policy.md",
"category": "policy"
},
{
"query": "How do I reset my password?",
"expected_keywords": ["settings", "email", "reset link"],
"expected_source": "account-help.md",
"category": "support"
},
{
"query": "What programming languages do you support?",
"expected_keywords": ["Python", "JavaScript"],
"expected_source": "technical-docs.md",
"category": "technical"
}
]Run the evaluation:
# Keyword matching only (fast, cheap)
python evaluate.py tests/test_dataset.json
# With LLM judge (slower, more accurate, costs money)
python evaluate.py tests/test_dataset.json --llm-judgeDependencies
# requirements.txt
fastapi==0.115.0
uvicorn[standard]==0.30.0
openai==1.50.0
chromadb==0.5.0
pydantic-settings==2.5.0
pypdf==4.3.0
python-dotenv==1.0.0
httpx==0.27.0Deployment with Docker
# Dockerfile
FROM python:3.12-slim
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy application code
COPY *.py .
# Create data directories
RUN mkdir -p /app/chroma_data /app/tests
# Health check
HEALTHCHECK \
CMD curl -f http://localhost:8000/health || exit 1
EXPOSE 8000
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "1"]Build and run:
# Build the image
docker build -t knowledge-bot .
# Run with environment variables
docker run -d \
--name knowledge-bot \
-p 8000:8000 \
-e OPENAI_API_KEY=sk-your-key \
-v $(pwd)/chroma_data:/app/chroma_data \
knowledge-botProduction Considerations
Before deploying to production, address these items:
Replace in-memory session store with Redis. The current sessions dict in api.py lives in a single process and is lost on restart. Use Redis with a TTL for session expiry:
import redis
import json
r = redis.Redis(host="localhost", port=6379, db=0)
def get_session(session_id: str) -> list[dict]:
data = r.get(f"session:{session_id}")
return json.loads(data) if data else []
def save_session(session_id: str, history: list[dict], ttl: int = 3600):
r.setex(f"session:{session_id}", ttl, json.dumps(history))Add authentication. The current API is open. In production, add API key validation or OAuth:
from fastapi import Depends, Security
from fastapi.security import APIKeyHeader
api_key_header = APIKeyHeader(name="X-API-Key")
async def verify_api_key(api_key: str = Security(api_key_header)):
if api_key not in VALID_API_KEYS:
raise HTTPException(status_code=403, detail="Invalid API key")
return api_key
# Apply to endpoints:
@app.post("/query", dependencies=[Depends(verify_api_key)])
async def query_knowledge_base(...):
...Use multiple Uvicorn workers behind a reverse proxy. A single worker handles one request at a time (for LLM calls, which are IO-bound, this is fine with async, but you still want redundancy):
uvicorn api:app --host 0.0.0.0 --port 8000 --workers 4Set up log aggregation. Ship the structured JSON logs from observability.py to your centralized logging system (ELK, Datadog, CloudWatch).
Schedule evaluation runs. Run evaluate.py daily via cron or CI to catch quality regressions:
# .github/workflows/eval.yml
name: Daily Eval
on:
schedule:
- cron: '0 6 * * *' # 6 AM UTC daily
jobs:
evaluate:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pip install -r requirements.txt
- run: python evaluate.py tests/test_dataset.json
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
- uses: actions/upload-artifact@v4
with:
name: eval-results
path: eval_results.jsonExtension Ideas
This project is a foundation. Here are concrete ways to extend it:
Multi-user support. Add user authentication and per-user document collections. Each user sees only their own knowledge base.
Feedback collection. Add a thumbs-up/thumbs-down endpoint. Store feedback with the trace ID so you can correlate user satisfaction with retrieval quality and prompt versions.
@app.post("/feedback")
async def submit_feedback(request_id: str, helpful: bool, comment: str = ""):
# Store in database, linked to the request log
passHybrid search. Combine vector search with keyword search (BM25). ChromaDB supports this natively, or you can use a separate search index.
Reranking. After vector search, rerank results with a cross-encoder model (e.g., Cohere Rerank) for better precision.
Multi-modal ingestion. Add support for images (with vision models), spreadsheets (CSV/Excel parsing), and HTML (with content extraction).
Admin dashboard. Build a simple frontend that shows the cost dashboard, quality metrics, and lets you browse recent queries and answers.
Key Takeaways
-
A production RAG system is more than retrieval + generation. It includes ingestion pipelines, input validation, cost tracking, session management, error handling, evaluation, and deployment infrastructure.
-
Configuration belongs in one place. Every tunable parameter — model names, chunk sizes, token limits, cost rates — should live in a config file or environment variables. No magic numbers in business logic.
-
Chunking quality determines retrieval quality. If your chunks are too large, you waste tokens on irrelevant context. Too small, and you lose coherence. The recursive splitting approach with overlap is a solid default.
-
Always track costs per request. Token costs are variable and can spike unexpectedly. Log prompt tokens, completion tokens, and estimated cost for every LLM call. Set daily budget alerts.
-
Streaming transforms user experience. Users perceive streaming responses as faster even when total latency is the same. The SSE implementation adds minimal complexity and makes the bot feel responsive.
-
Input validation is your first line of defense. Rate limiting, length checks, and prompt injection detection should happen before any tokens are spent. Reject bad input early and cheaply.
-
Evaluation is not optional. Without automated eval, you’re flying blind. Build a test dataset, run it regularly, and track scores over time. Quality drift happens silently — scheduled evals are the only way to catch it.
-
Session management needs persistence in production. In-memory session stores work for prototyping. For production, use Redis or a database with TTL-based expiry.
-
Start simple, then extend. This project is ~400 lines of core logic. It handles the common case well. Add reranking, hybrid search, and multi-modal support only when your evaluation data shows you need them.
-
Every technique in this project was covered in earlier lessons. Prompt engineering (Lesson 3-4), RAG (Lesson 5), chunking (Lesson 7), evaluation (Lesson 8), streaming (Lesson 9), cost optimization (Lesson 10), security (Lesson 11-12), agents (Lesson 13), and observability (Lesson 14). The capstone project is where they all connect.