CAG: Cache-Augmented Generation
Summary
CAG performs well, but it does not solve the key problem RAG was created to address: small context windows.
Retrieval-Augmented Generation (RAG) is currently (early 2025) the most popular way to use external knowledge in LLM applications. RAG allows you to enhance your LLM with data beyond the data it was trained on. There are many great RAG solutions and products.
RAG has some drawbacks:
- There can be significant retrieval latency as it searches for and organizes the correct data.
- There can be errors in the documents/data it selects as results for a query. For example, it may select the wrong document or give priority to the wrong document.
- It may introduce security and data issues 2️⃣.
- It introduces additional complexity:
  - an external application to manage the data (a vector database)
  - a process to continually update this data when it goes stale
Cache-Augmented Generation (CAG) has been proposed as an alternative approach 1️⃣ that I suggest could make a great complement to RAG. CAG:
- Is faster than RAG
- Is useful where the data can fit into the context window of your model
- Is less complicated than RAG
- Can give better quality results than RAG (measured by BERTScore)
I am going to cover CAG in this post.
What is CAG
CAG preprocesses all relevant data and loads it into the LLM’s extended context window.
- Prepare Dataset: Preprocess all relevant knowledge documents.
- Preload Context: Load the dataset into the LLM’s extended context window. For example, if using Ollama the default context is 2048 tokens, however many models support up to a 128K context. When using the model for CAG you will need to configure it to use its full context (or as much as you need); see the sketch after this list.
- Cache Inference State: Store the inference state (the KV cache) for repeated queries.
- Query Model: Directly interact with the model using the cached knowledge.
- Generate Outputs: Produce final results without retrieval latency.
- Cache Reset: After generating a response, the KV cache can be reset to its initial state to prepare for the next query.
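As a minimal sketch of the context-size configuration (assuming a local Ollama server; the model name and prompt are placeholders, and 32768 is just an example value that your model must actually support), you can raise the context window per request with the num_ctx option:

import requests

# Ask Ollama to use a larger context window for this request.
# Ollama's default num_ctx is 2048; here we request 32768 tokens.
response = requests.post(
    "http://localhost:11434/api/generate",
    json={
        "model": "llama3.1",  # placeholder model name
        "prompt": "Summarise the preloaded documents.",
        "options": {"num_ctx": 32768},
        "stream": False,
    },
)
print(response.json()["response"])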
What are the benefits and drawbacks of CAG
It is faster than RAG and less complicated; however, it does require a large context size, as the context is populated with all the documentation. CAG’s primary limitation is the context window size, as this caps the amount of preloaded data.
| Feature | RAG | CAG |
|---|---|---|
| Performance | Performs real-time retrieval of information during inference. This can slow down response times. | Preloads all relevant knowledge into the model’s context, providing faster response times. |
| Errors | Subject to potential errors in document selection and ranking. | Minimizes retrieval errors by ensuring the holistic context is present. |
| Complexity | Integrates retrieval, update and generation components, which increases system complexity. | Simplifies the architecture, however we do need to manage the KV cache. |
| Context | Dynamically added with each new query. | Context comes from the preloaded data. |
| Memory Usage | Uses additional memory and resources for external retrieval. | Uses the preloaded KV cache for efficient resource management. However, a larger context can lead to challenges. |
| Correctness | Is a standard solution usually built using fit-for-purpose tools like Postgres. | Is an unsupported solution that uses the cache to store non-transient data, which it was not designed to do. |
When should you use CAG
- Static Data: Datasets that don’t change frequently.
- Small to Medium Dataset Size: Knowledge fits within the LLM’s context window (32k–100k tokens).
- Low-Latency Use Cases: Scenarios where speed is critical (e.g., real-time chat systems).
Calculating the number of tokens in text
llama-token-counter can be used to calculate tokens.
To give you an idea of token counts, I used it to calculate the token count for the Bible.
Authorized (King James) Version (AKJV):
- Characters (without line endings): 4,478,421
- Words: 890,227
- Lines: 31,105
- Document length: 4,544,631
- Number of tokens: 1,365,259
Code to calculate token count
import os
from transformers import AutoTokenizer

def get_token_count(text, model_name="meta-llama/Meta-Llama-3-8B"):
"""
Calculates the number of tokens in the given text for a specified LLaMA model.
Args:
text: The input text string.
model_name: The name of the LLaMA model (default: "meta-llama/Meta-Llama-3-8B").
Returns:
The number of tokens in the text.
"""
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokens = tokenizer.encode(text)
return len(tokens)
def get_text_from_file(file_path):
"""
Reads the entire contents of a text file.
Args:
file_path: The path to the text file.
Returns:
The contents of the file as a single string.
"""
try:
with open(file_path, 'r') as file:
text_data = file.read()
return text_data
except FileNotFoundError:
print(f"Error: File not found at '{file_path}'")
return None
data_path = os.path.abspath("shakespeare.txt")
text_data = get_text_from_file(data_path)
token_count = get_token_count(text_data)
print(f'Number of tokens in the file {data_path}: {"{:,}".format(token_count)}')
Determine the usable context size for your model
RULER: What’s the Real Context Size of Your Long-Context Language Models? is research from NVIDIA which measures the effective context length of many open source models.
Essentially, what they found was that as the context length increases, model performance decreases.
The key point I would take from this research is that even though your model may claim to have a 128K context, it may really only support half this amount with acceptable quality.
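As a rough sketch building on the get_token_count and get_text_from_file helpers above (the 0.5 effective-context fraction and the 2,048-token reserve are my own assumptions inspired by the RULER findings, not figures from the paper), you can check whether a dataset is a realistic candidate for CAG:

def fits_in_context(text, advertised_context=131072, effective_fraction=0.5, reserve_tokens=2048):
    """
    Estimates whether a document is likely to fit in a model's usable context.

    effective_fraction is an assumed discount on the advertised context size,
    and reserve_tokens leaves room for the question and the generated answer.
    """
    usable = int(advertised_context * effective_fraction) - reserve_tokens
    count = get_token_count(text)
    return count <= usable, count, usable

ok, count, usable = fits_in_context(get_text_from_file("shakespeare.txt"))
print(f"Tokens: {count:,} / usable budget: {usable:,} -> {'CAG candidate' if ok else 'consider RAG'}")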
CAG Implementation
Generate the KV cache for the model
import torch
from transformers import DynamicCache

def generate(model, input_ids: torch.Tensor, past_key_values, max_new_tokens: int = 50) -> torch.Tensor:
"""
Generates a sequence of tokens using the given model.
Args:
model: The language model to use for generation.
input_ids (torch.Tensor): The input token IDs.
past_key_values: The past key values for the model's attention mechanism.
max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 50.
Returns:
torch.Tensor: The generated token IDs, excluding the input tokens.
"""
device = model.model.embed_tokens.weight.device
origin_len = input_ids.shape[-1]
input_ids = input_ids.to(device)
output_ids = input_ids.clone()
next_token = input_ids
with torch.no_grad():
for _ in range(max_new_tokens):
out = model(input_ids=next_token, past_key_values=past_key_values, use_cache=True)
logits = out.logits[:, -1, :]
token = torch.argmax(logits, dim=-1, keepdim=True)
output_ids = torch.cat([output_ids, token], dim=-1)
past_key_values = out.past_key_values
next_token = token.to(device)
if model.config.eos_token_id is not None and token.item() == model.config.eos_token_id:
break
return output_ids[:, origin_len:]
def get_kv_cache(model, tokenizer, prompt: str) -> DynamicCache:
"""
Generates a key-value cache for a given model and prompt.
Args:
model: The language model to use for generating the cache.
tokenizer: The tokenizer associated with the model.
prompt (str): The input prompt for which the cache is generated.
Returns:
DynamicCache: The generated key-value cache.
"""
device = model.model.embed_tokens.weight.device
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
cache = DynamicCache()
with torch.no_grad():
_ = model(input_ids=input_ids, past_key_values=cache, use_cache=True)
return cache
def clean_up(cache: DynamicCache, origin_len: int):
"""
Trims the key_cache and value_cache tensors in the given DynamicCache object.
Args:
cache (DynamicCache): The cache object containing key_cache and value_cache tensors.
origin_len (int): The length to which the tensors should be trimmed.
Returns:
None
"""
for i in range(len(cache.key_cache)):
cache.key_cache[i] = cache.key_cache[i][:, :, :origin_len, :]
cache.value_cache[i] = cache.value_cache[i][:, :, :origin_len, :]
How to use CAG
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load a model; I chose Phi-3 because it is small with a large context window
model_name = "microsoft/Phi-3-mini-128k-instruct"
hf_token = os.getenv("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained(
model_name, token=hf_token, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
device_map="auto",
trust_remote_code=True,
token=hf_token,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
print(f"Loaded {model_name}.")
# Load a file for testing
with open("genesis.txt", "r", encoding="utf-8") as f:
doc_text = f.read()
system_prompt = f"""
<|system|>
You are an assistant who provides concise factual answers.
You strive to just answer the user's question.
<|user|>
Context:
{doc_text}
Question:
""".strip()
genesis_cache = get_kv_cache(model, tokenizer, system_prompt)
origin_len = genesis_cache.key_cache[0].shape[-2]
print("KV cache built.")
# query the cache
question1 = "Why did God create eve?"
clean_up(genesis_cache, origin_len)
input_ids_q1 = tokenizer(question1 + "\n", return_tensors="pt").input_ids.to(device)
gen_ids_q1 = generate(model, input_ids_q1, genesis_cache)
answer1 = tokenizer.decode(gen_ids_q1[0], skip_special_tokens=True)
print("Q1:", question1)
print("A1:", answer1)
Practical Implementations
The reason we have RAG is that we do not have large enough context windows to handle our data requests. CAG does not solve this problem; RAG does. Some further points to consider:
- CAG does not support dynamic data. Arguably RAG does not either, but it has solutions and approaches for this.
- You are putting data into an unmanaged memory area; the KV cache is essentially not designed for this. Compare it to RAG, which typically backends to an ACID database.
- LLMs, particularly closed ones, may not provide a hook where you can inject your KV structures into the inference process.
If you have an application that manages a model and you need maximum performance over a dataset that is relatively static, CAG is a good solution. For example, you are building a personal chatbot over a user’s diary or their lifetime conversations with the bot.
References
1️⃣ Don’t Do RAG: When Cache-Augmented Generation is All You Need for Knowledge Tasks
2️⃣ Pirates of the RAG: Adaptively Attacking LLMs to Leak Knowledge Bases
3️⃣ Not All Contexts Are Equal: Teaching LLMs Credibility-aware Generation
RULER: What’s the Real Context Size of Your Long-Context Language Models?
Code examples
CAG notebooks: https://github.com/ernanhughes/cag-noteboooks has some example code for using CAG.