Source code for nemo_automodel._transformers.auto_tokenizer

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Callable, Optional, Type, Union

logger = logging.getLogger(__name__)


def _get_model_type(pretrained_model_name_or_path: str, trust_remote_code: bool = False) -> Optional[str]:
    """
    Determine the model type from the config.

    Args:
        pretrained_model_name_or_path: Model identifier or path
        trust_remote_code: Whether to trust remote code

    Returns:
        The model_type string, or None if it cannot be determined
    """
    try:
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        return getattr(config, "model_type", None)
    except Exception as e:
        logger.debug(f"Could not load config to determine model type: {e}")
        return None
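
# Usage sketch (illustrative, not part of the original module): the helper reads
# the `model_type` field from the checkpoint's config. The model id below is an
# example and assumes the config is reachable (local path or Hub access).
#
#     >>> _get_model_type("mistralai/Mistral-7B-v0.1")
#     'mistral'
#     >>> _get_model_type("/nonexistent/path")  # config load fails -> None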


def _get_tokenizer_registry():
    # Import lazily to avoid pulling in optional/custom backends (and transformers)
    # when users only do `from nemo_automodel import NeMoAutoTokenizer`.
    from nemo_automodel._transformers.tokenization.registry import TokenizerRegistry

    return TokenizerRegistry


class NeMoAutoTokenizer:
    """
    Auto tokenizer class that dispatches to appropriate tokenizer implementations.

    Similar to HuggingFace's AutoTokenizer, but with a custom registry for
    specialized tokenizer implementations.

    The dispatch logic is:
    1. If a custom tokenizer is registered for the model type, use it
    2. Otherwise, fall back to NeMoAutoTokenizerWithBosEosEnforced

    Example:
        >>> # Will use MistralCommonBackend if available for Mistral models
        >>> tokenizer = NeMoAutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

        >>> # Force using HF AutoTokenizer with BOS/EOS enforcement
        >>> tokenizer = NeMoAutoTokenizer.from_pretrained("gpt2", force_default=True)
    """

    # Make registry accessible at class level
    _registry = None

    @classmethod
    def register(cls, model_type: str, tokenizer_cls: Union[Type, Callable]) -> None:
        """
        Register a custom tokenizer for a specific model type.

        Args:
            model_type: The model type string (e.g., "mistral", "llama")
            tokenizer_cls: The tokenizer class or factory function
        """
        _get_tokenizer_registry().register(model_type, tokenizer_cls)
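
    # Usage sketch: registering a custom backend. `MyMistralTokenizer` is a
    # hypothetical class; dispatch only requires that the registered object
    # expose a `from_pretrained` constructor.
    #
    #     >>> class MyMistralTokenizer:
    #     ...     @classmethod
    #     ...     def from_pretrained(cls, name_or_path, *args, **kwargs):
    #     ...         return cls()
    #     >>> NeMoAutoTokenizer.register("mistral", MyMistralTokenizer)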

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *args,
        force_default: bool = False,
        force_hf: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        """
        Load a tokenizer from a pretrained model.

        Args:
            pretrained_model_name_or_path: Model identifier or path
            force_default: If True, always use NeMoAutoTokenizerWithBosEosEnforced
            force_hf: If True, return the raw HF AutoTokenizer without any wrapping
            trust_remote_code: Whether to trust remote code when loading config
            **kwargs: Additional arguments passed to the tokenizer's from_pretrained

        Returns:
            A tokenizer instance appropriate for the model type
        """
        # If force_hf, just use the base HF AutoTokenizer
        if force_hf:
            from transformers import AutoTokenizer

            return AutoTokenizer.from_pretrained(
                pretrained_model_name_or_path, *args, trust_remote_code=trust_remote_code, **kwargs
            )

        # Try to determine model type from config
        model_type = _get_model_type(pretrained_model_name_or_path, trust_remote_code=trust_remote_code)
        registry = _get_tokenizer_registry()

        if not force_default and model_type:
            tokenizer_cls = registry.get_custom_tokenizer_cls(model_type)
            if tokenizer_cls is not None:
                logger.info(f"Using custom tokenizer {tokenizer_cls.__name__} for model type '{model_type}'")
                return tokenizer_cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

        # Fall back to default BOS/EOS enforced tokenizer
        from nemo_automodel._transformers.tokenization.nemo_auto_tokenizer import NeMoAutoTokenizerWithBosEosEnforced

        return NeMoAutoTokenizerWithBosEosEnforced.from_pretrained(
            pretrained_model_name_or_path, *args, trust_remote_code=trust_remote_code, **kwargs
        )
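

# Usage sketch: the three dispatch modes described in the docstring above,
# assuming "gpt2" resolves as a standard Hub model id.
#
#     >>> tok = NeMoAutoTokenizer.from_pretrained("gpt2")                      # registry, then BOS/EOS wrapper
#     >>> tok = NeMoAutoTokenizer.from_pretrained("gpt2", force_default=True)  # skip the custom registry
#     >>> tok = NeMoAutoTokenizer.from_pretrained("gpt2", force_hf=True)       # raw HF AutoTokenizer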


__all__ = [
    "NeMoAutoTokenizer",
    "NeMoAutoTokenizerWithBosEosEnforced",
    "TokenizerRegistry",
]


def __getattr__(name: str):
    if name == "TokenizerRegistry":
        return _get_tokenizer_registry()
    if name == "NeMoAutoTokenizerWithBosEosEnforced":
        from nemo_automodel._transformers.tokenization.nemo_auto_tokenizer import NeMoAutoTokenizerWithBosEosEnforced

        return NeMoAutoTokenizerWithBosEosEnforced
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
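
# Illustration: the module-level __getattr__ above implements PEP 562 lazy
# attributes, so importing this module stays cheap and the heavier tokenizer
# implementations are only pulled in when a name is first accessed.
#
#     >>> from nemo_automodel._transformers import auto_tokenizer
#     >>> auto_tokenizer.TokenizerRegistry  # first access triggers the import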


def __dir__():
    return sorted(__all__)