Custom Logits Processing#
NIM LLM supports custom logits processing by passing through to vLLM’s native logits processor API. You can write your own processor, mount it into the container as a volume, and enable it with a single CLI flag.
Custom logits processors operate at the batch level: vLLM batches multiple requests into a single logits tensor for GPU efficiency, and each row of that tensor corresponds to one request. Your processor therefore receives the entire batch together. However, clients can toggle behavior per request by passing custom arguments in `vllm_xargs`; your processor checks each row's `extra_args` and modifies only the rows that opt in.
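The per-row opt-in pattern can be sketched without any vLLM dependency. In this toy illustration, `logits` is a plain list of rows (one per request) and `per_request_extra_args` stands in for each request's `vllm_xargs` payload; the names are illustrative, not part of the NIM or vLLM API.

```python
# Toy sketch of batch-level processing with per-row opt-in.
# "logits" is a list of rows, one per request in the batch.
NEG_INF = float("-inf")

def apply_batch(logits, per_request_extra_args):
    """Mask a banned token only for rows whose request opted in."""
    for row, extra_args in zip(logits, per_request_extra_args):
        ban = (extra_args or {}).get("ban_token_id")
        if ban is not None:
            row[ban] = NEG_INF  # only this request's row is modified
    return logits

batch = [[0.1, 0.9, 0.3], [0.2, 0.8, 0.4]]
args = [{"ban_token_id": 1}, None]  # request 0 opts in; request 1 does not
out = apply_batch(batch, args)
```

Only the first row is changed; the second request, which sent no `ban_token_id`, passes through untouched.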
Run a Custom Logits Processor#
To run a custom logits processor, complete the following steps:
1. Place your processor script in a local directory.
2. Volume-mount that directory into the container.
3. Pass the processor's module path to NIM by using the `--logits-processors` CLI argument:

```shell
docker run --gpus all \
  -v /home/user/my_processors:/opt/nim/my_processors \
  -e PYTHONPATH=/opt/nim \
  -e NIM_MODEL_PATH=hf://meta-llama/Llama-3.2-1B-Instruct \
  -p 8000:8000 \
  nim-llm:local \
  nim-serve --logits-processors my_processors.token_filter:BadWordFilterProcessor
```
The `--logits-processors` value uses Python's `module.submodule:ClassName` format. The module must be importable from `PYTHONPATH`.
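The `module.submodule:ClassName` format resolves through Python's standard import machinery, which a short sketch makes concrete. The `load_processor` helper and the stdlib spec below are illustrative stand-ins, not part of NIM:

```python
# Sketch of how a module.submodule:ClassName spec resolves.
import importlib

def load_processor(spec: str):
    module_path, class_name = spec.split(":")
    module = importlib.import_module(module_path)  # module must be on PYTHONPATH
    return getattr(module, class_name)

# Stdlib stand-in for a spec like my_processors.token_filter:BadWordFilterProcessor
cls = load_processor("collections:OrderedDict")
```

If the module is not on `PYTHONPATH`, the import step raises `ModuleNotFoundError`, which is the most common failure mode when the volume mount or `PYTHONPATH` value is wrong.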
Write a Custom Logits Processor#
Custom processors extend `vllm.v1.sample.logits_processor.LogitsProcessor` and implement the following five methods:

| Method | Purpose |
|---|---|
| `validate_params` | Class method. Validates per-request parameters when a request arrives. Raise `ValueError` to reject a request with invalid arguments. |
| `__init__` | Runs during server startup. Initialize any state your processor needs. |
| `apply` | Called every engine step. Receives the full logits tensor (one row per request) and returns it, modified in place or as a new tensor. |
| `is_argmax_invariant` | Return `True` only if the processor can never change the highest-logit token; vLLM uses this to skip the processor when sampling is greedy. |
| `update_state` | Called every engine step with a `BatchUpdate` describing requests added to, removed from, or moved within the persistent batch; update your per-request state accordingly. |
Example: BanTokenProcessor#
The following processor bans a single token per request. Requests opt in by sending `{"vllm_xargs": {"ban_token_id": <int>}}`. Requests that do not send this parameter are unaffected.
```python
import torch
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor, MoveDirectionality


class BanTokenProcessor(LogitsProcessor):
    @classmethod
    def validate_params(cls, params: SamplingParams):
        ban = (params.extra_args or {}).get("ban_token_id")
        if ban is not None and not isinstance(ban, int):
            raise ValueError(f"ban_token_id must be int, got {type(ban).__name__}")

    def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool):
        # Maps persistent batch index -> token ID to ban for that request.
        self.banned: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        # Banning a token can change the argmax, so vLLM must always apply this processor.
        return False

    def update_state(self, batch_update: BatchUpdate | None) -> None:
        if not batch_update:
            return
        # A new request may reuse a batch slot, so always overwrite or clear it.
        for index, params, _, _ in batch_update.added:
            ban = (params.extra_args or {}).get("ban_token_id")
            if ban is not None:
                self.banned[index] = ban
            else:
                self.banned.pop(index, None)
        if not self.banned:
            return
        for index in batch_update.removed:
            self.banned.pop(index, None)
        for src, dst, direction in batch_update.moved:
            src_val = self.banned.pop(src, None)
            dst_val = self.banned.pop(dst, None)
            if src_val is not None:
                self.banned[dst] = src_val
            if direction == MoveDirectionality.SWAP and dst_val is not None:
                self.banned[src] = dst_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.banned:
            return logits
        for req_idx, token_id in self.banned.items():
            logits[req_idx, token_id] = float("-inf")
        return logits
```
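The trickiest part of a processor like this is the `update_state` bookkeeping, because batch indices are reused and shuffled as requests come and go. The following standalone sketch mirrors that logic with the vLLM types replaced by plain values so it runs without a GPU or vLLM install; `track`, `SWAP`, and `UNIDIRECTIONAL` are illustrative stand-ins for the real `BatchUpdate` and `MoveDirectionality` machinery.

```python
# Standalone sketch of update_state-style bookkeeping with plain values.
SWAP, UNIDIRECTIONAL = "swap", "unidirectional"

def track(banned, added=(), removed=(), moved=()):
    for index, ban in added:            # (batch index, ban_token_id or None)
        if ban is not None:
            banned[index] = ban
        else:
            banned.pop(index, None)     # slot may be reused by a new request
    for index in removed:
        banned.pop(index, None)
    for src, dst, direction in moved:
        src_val = banned.pop(src, None)
        dst_val = banned.pop(dst, None)
        if src_val is not None:
            banned[dst] = src_val
        if direction == SWAP and dst_val is not None:
            banned[src] = dst_val
    return banned

# Request in slot 0 opts in; the request in slot 1 does not.
state = track({}, added=[(0, 42), (1, None)])
# vLLM later moves the request from slot 0 to slot 1; the ban follows it.
state = track(state, moved=[(0, 1, UNIDIRECTIONAL)])
```

After the move, the ban is tracked under the request's new slot, which is exactly why per-request state must be keyed by the indices `BatchUpdate` reports rather than by request order.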
Per-Request Toggling#
Custom logits processors operate at the batch level, but clients can enable a specific processor per request by passing custom arguments in `vllm_xargs`.
```shell
curl -X POST http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Llama-3.2-1B-Instruct",
    "prompt": "Once upon a time...",
    "max_tokens": 64,
    "vllm_xargs": {"ban_token_id": 42}
  }'
```
For more details on invoking custom logits processors per request, refer to the vLLM documentation.
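For clients written in Python, the same request can be sketched with the standard library alone. This is an assumption-light equivalent of the curl call above: building the payload is shown in full, while actually sending it requires a NIM server running at `localhost:8000`.

```python
# Stdlib sketch of the curl request above.
import json
import urllib.request

payload = {
    "model": "meta-llama/Llama-3.2-1B-Instruct",
    "prompt": "Once upon a time...",
    "max_tokens": 64,
    "vllm_xargs": {"ban_token_id": 42},  # per-request opt-in for the processor
}

request = urllib.request.Request(
    "http://localhost:8000/v1/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
# response = urllib.request.urlopen(request)  # uncomment with a live server
```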