KV Cache Reuse (a.k.a. prefix caching)

Date: 2025-01-13

How to use

Enable KV cache reuse by setting the environment variable NIM_ENABLE_KV_CACHE_REUSE to 1. See the configuration documentation for more information.
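
As a rough illustration of where the variable goes, the sketch below assembles a container launch command that passes it through. It assumes the service is started with `docker run`; the helper name `build_docker_run_args` and the image name are hypothetical placeholders, the port is taken from the demo script's URL, and any other options your deployment needs (credentials, cache mounts, and so on) are omitted.

```python
from typing import List


def build_docker_run_args(
    image: str,
    enable_kv_cache_reuse: bool = True,
    port: int = 8000,
) -> List[str]:
    """Build an argument list for `docker run` (hypothetical helper)."""
    args = ["docker", "run", "--rm", "--gpus", "all", "-p", f"{port}:{port}"]
    if enable_kv_cache_reuse:
        # The only setting taken from this page: the variable must be 1.
        args += ["-e", "NIM_ENABLE_KV_CACHE_REUSE=1"]
    args.append(image)
    return args


if __name__ == "__main__":
    # "your_nim_image_here" is a placeholder, not a real image name.
    print(" ".join(build_docker_run_args("your_nim_image_here")))
```

The list can be passed to `subprocess.run` as is, or printed and pasted into a shell.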

When to use

When more than 90% of the initial prompt is identical across requests and only the final tokens differ, enabling KV cache reuse can substantially improve inference speed. The key-value (KV) entries computed for the shared prefix are reused, so only the differing tokens at the end of the prompt need to be processed.
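
To get a quick feel for whether a workload is near that 90% mark, you can estimate how much of two prompts is shared. The sketch below is only an approximation: it compares characters, not tokens, and the function name `shared_prefix_fraction` is made up for this example.

```python
import os


def shared_prefix_fraction(prompt_a: str, prompt_b: str) -> float:
    """Fraction of the longer prompt covered by the common leading characters."""
    longest = max(len(prompt_a), len(prompt_b))
    if longest == 0:
        return 0.0
    # commonprefix compares the strings character by character.
    common = os.path.commonprefix([prompt_a, prompt_b])
    return len(common) / longest


if __name__ == "__main__":
    document = "| 1 | John Doe | 29 | Engineer | USA |\n" * 100
    prompt_a = document + "Question: What is the age of John Doe?"
    prompt_b = document + "Question: What is the occupation of John Doe?"
    print(f"Shared prefix: {shared_prefix_fraction(prompt_a, prompt_b):.1%}")
```

A token-level comparison with the model's own tokenizer would be more faithful, but the character-level ratio is usually close enough to judge whether the feature is worth enabling.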

For example, when a user asks several questions about a large document, the document is repeated in every request and only the question at the end of the prompt changes. With this feature enabled, time-to-first-token (TTFT) typically improves by about 2x.

Example:

  • Large table input followed by a question about the table

  • Same large table input followed by a different question about the table

  • Same large table input followed by a different question about the table

  • and so forth…

KV cache reuse speeds up TTFT from the second request onward, because the first request is the one that populates the cache.
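
Why the first request gets no benefit can be seen in a toy model of prefix caching. This is an illustrative sketch only, not a description of how the server implements the feature: the class `ToyPrefixCache`, the block size, and the use of plain string tokens are all assumptions made for the example.

```python
from typing import List, Set, Tuple

BLOCK_SIZE = 4  # tokens per cached block (illustrative value)


class ToyPrefixCache:
    """Remembers which block-aligned prompt prefixes have been seen before."""

    def __init__(self) -> None:
        self._prefixes: Set[Tuple[str, ...]] = set()

    def process(self, tokens: List[str]) -> int:
        """Return how many leading tokens were reusable, then cache this prompt."""
        full = len(tokens) - len(tokens) % BLOCK_SIZE
        reused = 0
        for end in range(BLOCK_SIZE, full + 1, BLOCK_SIZE):
            # A block is reusable only if everything before it matches too.
            if tuple(tokens[:end]) not in self._prefixes:
                break
            reused = end
        for end in range(BLOCK_SIZE, full + 1, BLOCK_SIZE):
            self._prefixes.add(tuple(tokens[:end]))
        return reused


if __name__ == "__main__":
    cache = ToyPrefixCache()
    table = [f"t{i}" for i in range(10)]  # stands in for the large table
    print(cache.process(table + ["q", "a"]))  # first request: cache is empty
    print(cache.process(table + ["q", "b"]))  # second request: prefix is reused
```

In this model the first call reuses nothing and the second reuses every full block of the shared table, which mirrors the behavior described above.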

You can use the following script to demonstrate the speedup. It measures total request latency with a small max_tokens value, which approximates TTFT:

import time
import requests
import json

# Define your model endpoint URL
API_URL = "http://0.0.0.0:8000/v1/chat/completions"

# Function to send a request to the API and return the response time
def send_request(model, messages, max_tokens=15):
    data = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "top_p": 1,
        "frequency_penalty": 1.0
    }

    headers = {
        "accept": "application/json",
        "Content-Type": "application/json"
    }

    start_time = time.perf_counter()
    response = requests.post(API_URL, headers=headers, data=json.dumps(data))
    elapsed = time.perf_counter() - start_time

    # Fail with a clear HTTP error instead of a KeyError on an error payload
    response.raise_for_status()
    output = response.json()
    print(f"Output: {output['choices'][0]['message']['content']}")
    print(f"Generation time: {elapsed:.4f} seconds")
    return elapsed

# Test function demonstrating caching with a long prompt
def test_prefix_caching():
    model = "your_model_name_here"

    # Long document to simulate complex input. Note that `* 50` binds tighter
    # than `+`, so only the last row is repeated to make the table long.
    LONG_PROMPT = (
        "# Table of People\n"
        "| ID  | Name          | Age | Occupation    | Country       |\n"
        "|-----|---------------|-----|---------------|---------------|\n"
        "| 1   | John Doe      | 29  | Engineer      | USA           |\n"
        + "| 2   | Jane Smith    | 34  | Doctor        | Canada        |\n" * 50
    )

    # First query (cold cache: there is no cached prefix to reuse yet)
    messages_1 = [{"role": "user", "content": LONG_PROMPT + "Question: What is the age of John Doe?"}]
    print("\nFirst query (no caching):")
    first_time = send_request(model, messages_1)

    # Second query (the shared prefix can be served from the KV cache)
    messages_2 = [{"role": "user", "content": LONG_PROMPT + "Question: What is the occupation of Jane Smith?"}]
    print("\nSecond query (with prefix caching):")
    second_time = send_request(model, messages_2)
    print(f"\nSpeedup: {first_time / second_time:.2f}x")

if __name__ == "__main__":
    test_prefix_caching()
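
The script above times the whole request, not the first token. If you want to measure TTFT directly, one option is to request a streamed response and time the arrival of the first content chunk. The sketch below assumes the endpoint accepts `"stream": true` and returns OpenAI-style server-sent events (`data: {...}` lines ending with `data: [DONE]`); the function names are made up for this example, and the behavior should be checked against your deployment.

```python
import json
import time
from typing import Callable, Iterable, Optional


def first_content_time(
    lines: Iterable[bytes],
    clock: Callable[[], float] = time.perf_counter,
) -> Optional[float]:
    """Return the clock reading when the first non-empty content delta arrives."""
    for raw in lines:
        line = raw.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue  # skip blank keep-alive lines and other fields
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        choices = json.loads(payload).get("choices") or []
        if choices and choices[0].get("delta", {}).get("content"):
            return clock()
    return None


def measure_ttft(url: str, body: dict) -> Optional[float]:
    """Send a streaming request and return seconds until the first token."""
    import requests  # imported here so the parser above needs only the stdlib

    start = time.perf_counter()
    with requests.post(url, json={**body, "stream": True}, stream=True, timeout=120) as resp:
        resp.raise_for_status()
        arrived = first_content_time(resp.iter_lines())
    return None if arrived is None else arrived - start
```

Calling `measure_ttft(API_URL, {"model": model, "messages": messages_1, "max_tokens": 15})` once per query, in place of `send_request`, gives TTFT values that can be compared between the first and second request.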
