Custom Logits Processing#

NIM LLM supports custom logits processing by passing through to vLLM’s native logits processor API. You can write your own processor, mount it into the container as a volume, and enable it with a single CLI flag.

Custom logits processors operate at the batch level. vLLM batches multiple requests into a single tensor for GPU efficiency, so your processor receives the entire batch at once, with each row in the tensor corresponding to one request. You can still toggle behavior per request by having clients pass custom arguments in vllm_xargs: your processor checks each request's extra_args and modifies only the rows that opt in.
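Conceptually, the per-request opt-in pattern looks like the following sketch. Plain Python lists stand in for the logits tensor, and apply_ban is an illustrative name, not part of the vLLM API:

```python
NEG_INF = float("-inf")

def apply_ban(logits_rows, per_row_extra_args):
    """logits_rows: one list of logits per request in the batch.
    per_row_extra_args: one dict (or None) per request, as sent via vllm_xargs."""
    for row, extra_args in zip(logits_rows, per_row_extra_args):
        ban = (extra_args or {}).get("ban_token_id")
        if ban is not None:
            row[ban] = NEG_INF  # only this opted-in request is modified
    return logits_rows

rows = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
out = apply_ban(rows, [{"ban_token_id": 2}, None])
```

Only the first request opted in, so only its row is masked; the second row passes through untouched.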

Run a Custom Logits Processor#

To run a custom logits processor, complete the following steps:

  1. Place your processor script in a local directory.

  2. Volume-mount that directory into the container.

  3. Pass the processor’s module path to NIM by using the --logits-processors CLI argument:

    docker run --gpus all \
    -v /home/user/my_processors:/opt/nim/my_processors \
    -e PYTHONPATH=/opt/nim \
    -e NIM_MODEL_PATH=hf://meta-llama/Llama-3.2-1B-Instruct \
    -p 8000:8000 \
    nim-llm:local \
    nim-serve --logits-processors my_processors.token_filter:BadWordFilterProcessor
    

The --logits-processors value uses Python’s module.submodule:ClassName format. The code must be importable from PYTHONPATH.
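For the command above, the mounted directory would need a layout along these lines (the file and class names here are illustrative and match the example flag; the __init__.py may be empty, and with Python 3 namespace packages it is optional):

```
/home/user/my_processors/
├── __init__.py        # makes my_processors an importable package
└── token_filter.py    # defines BadWordFilterProcessor
```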

Write a Custom Logits Processor#

Custom processors extend vllm.v1.sample.logits_processor.LogitsProcessor and implement the following five methods:

| Method | Purpose |
| --- | --- |
| validate_params | Class method. Validates per-request parameters when a request arrives. Raise ValueError to reject the request. |
| __init__ | Runs during server startup. Initialize any state your processor needs. |
| apply | Called every engine step. Receives the full logits tensor (num_requests, vocab_size) and returns the modified tensor. |
| is_argmax_invariant | Return True if your processor never changes which token has the highest logit. Return False otherwise. |
| update_state | Called every engine step with a BatchUpdate describing which requests were added, removed, or moved in the batch. Use this to track per-request state. |

Example: BanTokenProcessor#

The following processor bans a single token per request. Requests opt in by sending {"vllm_xargs": {"ban_token_id": <int>}}. Requests that do not send this parameter are unaffected.

import torch
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor, MoveDirectionality


class BanTokenProcessor(LogitsProcessor):

    @classmethod
    def validate_params(cls, params: SamplingParams):
        ban = (params.extra_args or {}).get("ban_token_id")
        if ban is not None and not isinstance(ban, int):
            raise ValueError(f"ban_token_id must be int, got {type(ban).__name__}")

    def __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool):
        self.banned: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        return False

    def update_state(self, batch_update: BatchUpdate | None) -> None:
        if not batch_update:
            return

        for index, params, _, _ in batch_update.added:
            ban = (params.extra_args or {}).get("ban_token_id")
            if ban is not None:
                self.banned[index] = ban
            else:
                self.banned.pop(index, None)

        if not self.banned:
            return

        for index in batch_update.removed:
            self.banned.pop(index, None)

        for src, dst, direction in batch_update.moved:
            src_val = self.banned.pop(src, None)
            dst_val = self.banned.pop(dst, None)
            if src_val is not None:
                self.banned[dst] = src_val
            if direction == MoveDirectionality.SWAP and dst_val is not None:
                self.banned[src] = dst_val

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.banned:
            return logits

        for req_idx, token_id in self.banned.items():
            logits[req_idx, token_id] = float("-inf")

        return logits
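The trickiest part of update_state is the moved handling: when two requests swap batch slots, their per-request state must swap with them. The following standalone sketch replays that bookkeeping with plain dicts; the SWAP constant stands in for vLLM's MoveDirectionality.SWAP and is only for illustration:

```python
SWAP = "swap"  # stands in for MoveDirectionality.SWAP

def apply_moves(banned, moved):
    """banned: {batch_index: banned_token_id}.
    moved: list of (src, dst, direction) tuples, as in BatchUpdate.moved."""
    for src, dst, direction in moved:
        src_val = banned.pop(src, None)
        dst_val = banned.pop(dst, None)
        if src_val is not None:
            banned[dst] = src_val          # state follows the request to dst
        if direction == SWAP and dst_val is not None:
            banned[src] = dst_val          # on a swap, dst's state moves to src
    return banned

# Slot 0 bans token 7 and slot 2 bans token 9; swapping slots 0 and 2
# must exchange the two entries.
state = apply_moves({0: 7, 2: 9}, [(0, 2, SWAP)])
# state is now {2: 7, 0: 9}
```

Popping both entries before reassigning is what keeps a swap from clobbering one side's state.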

Per-Request Toggling#

Custom logits processors operate at the batch level, but clients can enable specific logits processors per request by passing custom arguments using vllm_xargs.

curl -X POST http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "meta-llama/Llama-3.2-1B-Instruct",
    "prompt": "Once upon a time...",
    "max_tokens": 64,
    "vllm_xargs": {"ban_token_id": 42}
  }'
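If you are calling the endpoint from Python instead, the same request body can be assembled with the standard library. This sketch only builds the JSON payload; sending it (for example with urllib, or via the extra_body parameter of an OpenAI-compatible client) depends on your setup:

```python
import json

payload = {
    "model": "meta-llama/Llama-3.2-1B-Instruct",
    "prompt": "Once upon a time...",
    "max_tokens": 64,
    # Extra arguments forwarded to custom logits processors per request.
    "vllm_xargs": {"ban_token_id": 42},
}
body = json.dumps(payload)
```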

For more details on invoking custom logits processors per request, refer to the vLLM documentation.