Optimizing the Training Pipeline and Models#

All Deep Neural Network tasks supported by TAO provide a train command that enables users to train models. Training can be done on one or more GPUs. NVIDIA TAO provides a simple command-line interface to train a deep-learning model for classification, object detection, and instance segmentation. To speed up the training process, the train command supports multi-GPU training. You can invoke a multi-GPU training session using the --gpus N option, where N is the number of GPUs you want to use. N must not exceed the number of GPUs available on the node.
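
For example, a minimal multi-GPU invocation might look like the following (the task name and spec path are placeholders; substitute the network and spec file you are training):

# Launch training on 2 GPUs for a given task
tao model <task> train -e /path/to/train_spec.yaml --gpus 2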

The following optimizations are also included with the train command:

Knowledge Distillation#

Knowledge distillation is a model compression technique in which a smaller, lightweight student model is trained to replicate the behavior of a larger, high-performing teacher model. By transferring knowledge from the teacher to the student, this approach enables efficient deployment of models in resource-constrained environments without a significant loss in accuracy.

The student model learns not only from the ground-truth labels but also from soft targets: the class probabilities derived from the teacher’s logits. These soft targets capture the teacher’s learned representations and subtle inter-class relationships, which can help the student generalize better than if it were trained on labeled data alone.
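
For reference, a common generic formulation of the logit-distillation objective (not necessarily the exact loss TAO uses) combines the standard cross-entropy on the ground-truth labels with a temperature-scaled divergence between the softened teacher and student outputs:

L_total = alpha * CE(y, softmax(z_s)) + (1 - alpha) * T^2 * KL(softmax(z_t / T) || softmax(z_s / T))

where z_s and z_t are the student and teacher logits, T is the softening temperature, and alpha balances the hard-label and soft-target terms.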

In addition to output-based distillation (using logits), feature distillation is another common strategy, in which the student is encouraged to match intermediate feature representations from the teacher. This allows the student to learn richer internal representations, often leading to improved performance on complex tasks.

Knowledge distillation is commonly used in scenarios where fast inference, low memory usage, or deployment on edge devices is critical.

Tips and Best Practices#

When applying knowledge distillation in practice:

  • Given a downstream task, we recommend that you plug in the teacher backbone and fine-tune it on the downstream data first. If the model performs well with the teacher, use the fine-tuned teacher to distill a student model that fits your compute budget.

  • If the teacher is ViT-based and the student is ConvNet-based, the student may struggle to learn from the teacher. ViT-to-ViT or ConvNet-to-ConvNet/ViT distillation generally yields better results. In other words, if the student must be a ConvNet, it’s better to use a ConvNet teacher.

  • If the student is ViT-based, consider starting with RADIO models as teachers. For image or video classification tasks, CLIP models may be more effective. For instance-level recognition or segmentation, MAE, ConvNeXtV2, or DINOv2 are strong candidates.

  • Choose the student model architecture based on your target compute budget. Keep in mind that smaller student models often require more training data to optimize effectively.

  • If training data is limited, try increasing the number of training epochs and applying more aggressive data augmentations to improve generalization.

TAO now supports knowledge distillation for several networks:

  • Feature distillation for object detection with RT-DETR

  • Backbone logits distillation over structured and unstructured data for image classification

  • Logits distillation for object detection with DINO

As of version 6.25.09, TAO introduces spatial feature distillation and Phi-Standardization (PHI-S) in the distillation loss. PHI-S is a technique that standardizes the feature maps of the teacher model to improve distillation performance.

TAO has also unified the backbone implementation for classification_pyt and all downstream tasks, allowing teacher backbones from downstream-trained models to be distilled into lighter student backbones supported by those tasks.

When choosing the student backbone to distill to, make sure the downstream task supports it. The following is an exhaustive list of options for distill.teacher.backbone.type, organized by module:

classification_pyt

  • faster_vit_0_224

  • faster_vit_1_224

  • faster_vit_2_224

  • faster_vit_3_224

  • faster_vit_4_224

  • faster_vit_5_224

  • faster_vit_6_224

  • faster_vit_4_21k_224

  • faster_vit_4_21k_384

  • faster_vit_4_21k_512

  • faster_vit_4_21k_768

  • fan_tiny_12_p16_224

  • fan_small_12_p16_224_se_attn

  • fan_small_12_p16_224

  • fan_base_18_p16_224

  • fan_large_24_p16_224

  • fan_tiny_8_p4_hybrid

  • fan_small_12_p4_hybrid

  • fan_base_16_p4_hybrid

  • fan_large_16_p4_hybrid

  • fan_swin_tiny_patch4_window7_224

  • fan_swin_small_patch4_window7_224

  • fan_swin_base_patch4_window7_224

  • fan_swin_large_patch4_window7_224

  • vit_large_patch14_dinov2_swiglu

  • vit_large_patch14_dinov2_swiglu_legacy

  • vit_giant_patch14_reg4_dinov2_swiglu

  • vit_base_patch16

  • vit_large_patch16

  • vit_huge_patch14

  • efficientvit_b0

  • efficientvit_b1

  • efficientvit_b2

  • efficientvit_b3

  • efficientvit_l0

  • efficientvit_l1

  • efficientvit_l2

  • efficientvit_l3

  • convnext_tiny

  • convnext_small

  • convnext_base

  • convnext_large

  • convnext_xlarge

  • convnextv2_atto

  • convnextv2_femto

  • convnextv2_pico

  • convnextv2_nano

  • convnextv2_tiny

  • convnextv2_base

  • convnextv2_large

  • convnextv2_huge

  • hiera_tiny_224

  • hiera_small_224

  • hiera_base_224

  • hiera_base_plus_224

  • hiera_large_224

  • hiera_huge_224

  • resnet_18

  • resnet_34

  • resnet_50

  • resnet_101

  • resnet_152

  • resnet_18d

  • resnet_34d

  • resnet_50d

  • resnet_101d

  • resnet_152d

  • swin_tiny_patch4_window7_224

  • swin_small_patch4_window7_224

  • swin_base_patch4_window7_224

  • swin_large_patch4_window7_224

  • swin_base_patch4_window12_384

  • swin_large_patch4_window12_384

  • gc_vit_xxtiny

  • gc_vit_xtiny

  • gc_vit_tiny

  • gc_vit_small

  • gc_vit_base

  • gc_vit_large

  • gc_vit_base_384

  • gc_vit_large_384

  • edgenext_xx_small

  • edgenext_x_small

  • edgenext_small

  • edgenext_base

  • edgenext_xx_small_bn_hs

  • edgenext_x_small_bn_hs

  • edgenext_small_bn_hs

  • c_radio_p1_vit_huge_patch16_mlpnorm

  • c_radio_p2_vit_huge_patch16_mlpnorm

  • c_radio_p3_vit_huge_patch16_mlpnorm

  • c_radio_v2_vit_base_patch16

  • c_radio_v2_vit_large_patch16

  • c_radio_v2_vit_huge_patch16

  • c_radio_v3_vit_large_patch16_reg4_dinov2

  • c_radio_v3_vit_base_patch16_reg4_dinov2

  • c_radio_v3_vit_huge_patch16_reg4_dinov2

  • mit_b0

  • mit_b1

  • mit_b2

  • mit_b3

  • mit_b4

  • mit_b5

  • vit_l_14_siglip_clipa_224

  • vit_l_14_siglip_clipa_336

  • vit_h_14_siglip_clipa_224

dino

  • resnet_34

  • resnet_50

  • fan_tiny

  • fan_small

  • fan_base

  • fan_large

  • swin_tiny_224_1k (alias: swin_tiny_patch4_window7_224)

  • swin_base_224_22k (alias: swin_base_patch4_window7_224)

  • swin_base_384_22k (alias: swin_base_patch4_window12_384)

  • swin_large_224_22k (alias: swin_large_patch4_window7_224)

  • swin_large_384_22k (alias: swin_large_patch4_window12_384)

  • efficientvit_b0

  • efficientvit_b1

  • efficientvit_b2

  • efficientvit_b3

  • vit_large_nvdinov2

  • vit_large_dinov2

mal

  • ViT family (arch strings with vit; patch sizes 8/14/16; sizes tiny/small/base/large/huge)

  • fan_tiny_12_p16_224

  • fan_small_12_p16_224

  • fan_base_18_p16_224

  • fan_large_24_p16_224

  • fan_tiny_8_p4_hybrid

  • fan_small_12_p4_hybrid

  • fan_base_16_p4_hybrid

  • fan_large_16_p4_hybrid

grounding_dino

  • resnet_50

  • swin_tiny_224_1k (alias: swin_tiny_patch4_window7_224)

  • swin_base_224_22k (alias: swin_base_patch4_window7_224)

  • swin_base_384_22k (alias: swin_base_patch4_window12_384)

  • swin_large_224_22k (alias: swin_large_patch4_window7_224)

  • swin_large_384_22k (alias: swin_large_patch4_window12_384)

mask_grounding_dino

  • resnet_50

  • swin_tiny_224_1k (alias: swin_tiny_patch4_window7_224)

  • swin_base_224_22k (alias: swin_base_patch4_window7_224)

  • swin_base_384_22k (alias: swin_base_patch4_window12_384)

  • swin_large_224_22k (alias: swin_large_patch4_window7_224)

  • swin_large_384_22k (alias: swin_large_patch4_window12_384)

rtdetr

  • resnet_18

  • resnet_34

  • resnet_50

  • resnet_101

  • convnext_tiny

  • convnext_small

  • convnext_base

  • convnext_large

  • convnext_xlarge

  • convnextv2_atto

  • convnextv2_femto

  • convnextv2_pico

  • convnextv2_nano

  • convnextv2_tiny

  • convnextv2_small

  • convnextv2_base

  • convnextv2_large

  • convnextv2_huge

  • efficientvit_b0

  • efficientvit_b1

  • efficientvit_b2

  • efficientvit_b3

  • efficientvit_l0

  • efficientvit_l1

  • efficientvit_l2

  • efficientvit_l3

  • fan_tiny_8_p4_hybrid

  • fan_small_12_p4_hybrid

  • fan_base_12_p4_hybrid

  • fan_large_12_p4_hybrid

  • edgenext_x_small

  • edgenext_small

  • edgenext_base

  • edgenext_xx_small_bn_hs

  • edgenext_x_small_bn_hs

  • edgenext_small_bn_hs

segformer

  • fan_tiny_8_p4_hybrid

  • fan_small_12_p4_hybrid

  • fan_base_16_p4_hybrid

  • fan_large_16_p4_hybrid

  • mit_b0

  • mit_b1

  • mit_b2

  • mit_b3

  • mit_b4

  • mit_b5

  • vit_large_nvdinov2

  • vit_giant_nvdinov2

  • vit_base_nvclip_16_siglip

  • vit_huge_nvclip_14_sig

  • c_radio_v2_vit_huge_patch16_224

  • c_radio_v2_vit_large_patch16_224

  • c_radio_v2_vit_base_patch16_224

  • c_radio_v3_vit_large_patch16_reg4_dinov2

visual_changenet

  • fan_tiny_8_p4_hybrid

  • fan_small_12_p4_hybrid

  • fan_base_16_p4_hybrid

  • fan_large_16_p4_hybrid

  • vit_large_nvdinov2

  • vit_large_dinov2

  • c_radio_p1_vit_huge_patch16_224_mlpnorm

  • c_radio_p2_vit_huge_patch16_224_mlpnorm

  • c_radio_p3_vit_huge_patch16_224_mlpnorm

  • c_radio_v2_vit_huge_patch16_224

  • c_radio_v2_vit_large_patch16_224

  • c_radio_v2_vit_base_patch16_224

Note

When using a downstream model as the teacher, make sure to set num_classes to 0 and mode to spatial in the distill config.
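
The following is a minimal, hypothetical sketch of the fields mentioned in this note; the exact nesting (in particular where num_classes lives) varies by task, so follow the task-specific distillation spec:

distill:
  mode: "spatial"                      # spatial feature distillation
  teacher:
    backbone:
      type: "fan_small_12_p4_hybrid"   # example; must be a backbone the task supports (see table above)
    num_classes: 0                     # set to 0 when the teacher comes from a downstream-trained model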

For more information on distillation for a specific task, refer to the distillation section of that network’s documentation:

  • rtdetr

  • classification_pyt

  • dino

Automatic Mixed Precision#

TAO now supports Automatic Mixed Precision (AMP) training. DNN training has traditionally relied on the IEEE single-precision (FP32) format for its tensors. With mixed-precision training, however, you can use a mixture of FP16 and FP32 operations in the training graph to help speed up training without compromising accuracy. There are several benefits to using AMP:

  • Speed up math-intensive operations such as linear and convolution layers

  • Speed up memory-limited operations by accessing half the bytes compared to single-precision

  • Reduce memory requirements for training models, enabling larger models or larger minibatches

In TAO, enabling AMP is as simple as setting the --use_amp flag on the command line when running the train command. This speeds up training by using FP16 Tensor Cores. Note that AMP is only supported on GPUs based on the Volta architecture or later.
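
For example (the task name and paths are placeholders):

# Enable AMP for a multi-GPU training run
tao model <task> train -e /path/to/train_spec.yaml --gpus 2 --use_amp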

Model Pruning#

Model pruning is one of the key differentiators for TAO. Pruning removes nodes that contribute little to the overall accuracy of the model from the neural network, which reduces the model size, significantly lowers the memory footprint, and increases inference throughput, all of which are important for edge deployment.

Currently, pruning is supported for a subset of Computer Vision models. The following graph provides an example of performance gains achieved when going from an unpruned CV model to a pruned CV model. (Inference was run on an NVIDIA T4; TrafficCamNet, DashCamNet, and PeopleNet are three of the custom pretrained models that are available on NGC.)

[Figure: Pruned vs. Unpruned Performance]

Quantization Aware Training#

TAO supports Quantization-Aware Training (QAT) for its TensorFlow 2 object detection network (EfficientDet-TF2) and TensorFlow 2 classification networks. QAT emulates inference-time quantization while training a model, which downstream inference platforms may then use to generate actual quantized models. The error from quantizing weights and tensors to INT8 is modeled during training, allowing the model to adapt and mitigate the error. During QAT, the model constructed in the training graph is modified to:

  1. Replace existing nodes with nodes that support fake quantization of their weights.

  2. Convert existing activations to ReLU-6 (except the output nodes).

  3. Add Quantize and De-Quantize (QDQ) nodes to compute the dynamic ranges of the intermediate tensors.

The dynamic ranges computed during training are serialized to a cache file at export, which may then be parsed by NVIDIA® TensorRT to create an optimized inference engine. To enable QAT during training, simply set the enable_qat parameter to true in the training_config field of the corresponding spec file for each of the supported networks. The benefit of QAT is usually better accuracy when running INT8 inference with TensorRT compared with traditional calibration-based INT8 TensorRT inference.
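
As an illustration, the flag placement implied by the field names above might look like the following; the exact spec format and nesting differ per network, so check the corresponding network’s spec documentation:

training_config:
  enable_qat: true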

Note

The number of scales present in the cache file is less than that generated by the Post-Training Quantization technique using TensorRT. This is because the QDQ nodes are added only after operations that TensorRT fuses on the GPU, for example, operation sequences such as Conv2d -> Bias -> ReLU or Conv2d -> Bias -> BatchNormalization -> Activation, whereas during PTQ, the scales are applied to all the intermediate tensors in the model. Also, the final output regression nodes are not quantized in the current training graphs, so these layers currently run in FP32.

Note

When deploying a model to platforms that have a DLA, note that using quantization cache files generated by peeling the scales from the model is currently not supported, since the DLA requires a scale factor for all layers. To use a QAT-trained model with a DLA, we recommend applying post-training quantization at export. The post-training quantization method takes the QAT-trained model and generates scale factors for all intermediate tensors in the model, since the DLA does not fuse operations the way the GPU does. More information can be found in the Exporting the Model section of each app.

The recommended workflow for training a Quantization Aware model is depicted in the diagram below.

[Figure: TAO CV QAT training workflow (all networks)]

Post-Training Quantization#

Post-Training Quantization (PTQ) converts a trained FP32/FP16 model to a lower-precision representation to reduce latency and memory with minimal or no retraining. In TAO, PTQ is provided via the TAO Quant library. See Quantizing a model in TAO (TAO Quant) for the full guide and backend details.

When to Use PTQ#

Use PTQ when you:

  • Cannot or do not want to retrain (for faster turnaround than QAT)

  • Want to establish a performance/accuracy baseline before investing in QAT

  • Are targeting edge deployment where INT8/FP8 inference and memory savings matter

Backends at a glance#

  • TorchAO (weight-only PTQ): The simplest path; quantizes weights only, requires no calibration loop, and provides modest speedups.

  • NVIDIA ModelOpt (static PTQ): Quantizes weights and activations; requires calibration data; yields larger speed gains.

Quick Workflow#

  1. Train your model as usual (FP32 or with AMP).

  2. Add a quantize section to your experiment spec, selecting a backend and mode.

  3. Run the task-specific quantize command.

  4. Evaluate with the quantized artifact and validate accuracy on your data.

  5. Deploy. For TensorRT export pipelines, follow the task’s export section and the TAO Quant documentation.

Minimal example (RT-DETR)#

quantize:
  model_path: "/path/to/trained_rtdetr.ckpt"
  results_dir: "/path/to/quantized_output"
  backend: "torchao"            # or "modelopt"
  mode: "weight_only_ptq"       # torchao
  # mode: "static_ptq"          # modelopt
  default_layer_dtype: "native"
  default_activation_dtype: "native"
  layers:
    - module_name: "Linear"
      weights: { dtype: "int8" }

# Quantize
tao model rtdetr quantize -e /path/to/spec.yaml

# Evaluate with the quantized checkpoint
# (set evaluate.is_quantized: true in your spec and point to the produced artifact)
tao model rtdetr evaluate -e /path/to/spec.yaml
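
To evaluate the quantized artifact, the spec flag referenced above can be set as follows (only the key named in this guide is shown; the key used to point at the produced checkpoint follows the task’s evaluate spec):

evaluate:
  is_quantized: true   # then point the evaluate spec at the artifact produced by the quantize step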

Limitations#

  • Current backends: torchao (weight-only PTQ) and modelopt (static PTQ).

  • Modes: PTQ only; QAT support in TAO Quant is planned but not yet available here.

  • Dtypes: INT8 and FP8 (E4M3FN/E5M2).

  • Tasks: classification_pyt and rtdetr.

  • Runtime: PyTorch; ONNX/TensorRT export is experimental.

  • For the most up-to-date, comprehensive list, see Limitations and current status.