Image Classification PyT#

Image Classification PyT is a PyTorch-based image-classification model included in TAO. It supports the following tasks:

  • train

  • evaluate

  • inference

  • export

  • distill

All of the above actions follow the command pattern below.

SPECS=$(tao-client classification_pyt get-spec --action <sub_task> --job_type experiment --id $EXPERIMENT_ID)

JOB_ID=$(tao-client classification_pyt experiment-run-action --action <sub_task> --id $EXPERIMENT_ID --specs "$SPECS")

Required Arguments

  • --id: The unique identifier of the experiment on which to run the action

See also

For information on how to create an experiment using the FTMS client, refer to the Creating an experiment section in the Remote Client documentation.

Preparing the Input Data Structure#

See the Data Annotation Format page for more information about the data format for image classification.

The classification experiment specification consists of seven main components:

  • model

  • dataset

  • train

  • evaluate

  • inference

  • export

  • distill

model#

The following shows how to configure Image Classification PyT with a FAN backbone.

First, set the base_experiment.

FILTER_PARAMS='{"network_arch": "classification_pyt"}'

BASE_EXPERIMENTS=$(tao-client classification_pyt list-base-experiments --filter_params "$FILTER_PARAMS")

Retrieve the PTM_ID of the FAN backbone from $BASE_EXPERIMENTS, then set it as the base_experiment.

PTM_INFORMATION="{\"base_experiment\": [$PTM_ID]}"

tao-client classification_pyt patch-artifact-metadata --id $EXPERIMENT_ID --job_type experiment --update_info $PTM_INFORMATION

Then retrieve the specifications.

TRAIN_SPECS=$(tao-client classification_pyt get-spec --action train --job_type experiment --id $EXPERIMENT_ID)

Inspect the specifications in $TRAIN_SPECS and override values as needed.
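For example, assuming the returned specs are JSON (as they are for the FTMS client), you can override the backbone type with a tool such as jq before submitting the job:

TRAIN_SPECS=$(echo "$TRAIN_SPECS" | jq '.model.backbone.type = "fan_small_12_p4_hybrid"')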

The model parameter primarily configures the backbone and head.

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| backbone | dict config | | The configuration of the backbone. | |
| head | dict config | | The configuration of the head. | |

backbone#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| type | str | fan_small_12_p4_hybrid | The backbone architecture. | See the supported backbone architectures below. |
| feat_downsample | bool | False | Flag to enable feature downsampling for the FAN-base backbone. | True, False |
| pretrained_backbone_path | str | | Path to the pretrained model. | |
| freeze_backbone | bool | False | Flag to freeze the backbone. | True, False |

Supported backbone architectures:

  • FAN variants: fan_tiny_8_p4_hybrid, fan_small_12_p4_hybrid, fan_base_16_p4_hybrid, fan_large_16_p4_hybrid, fan_Xlarge_16_p4_hybrid, fan_base_18_p16_224, fan_tiny_12_p16_224, fan_small_12_p16_224, fan_large_24_p16_224, fan_small_12_p16_224_se_attn

  • GCViT variants: gc_vit_xxtiny, gc_vit_xtiny, gc_vit_tiny, gc_vit_small, gc_vit_base, gc_vit_large

  • FasterViT variants: faster_vit_0_224, faster_vit_1_224, faster_vit_2_224, faster_vit_3_224, faster_vit_4_224, faster_vit_5_224, faster_vit_6_224, faster_vit_4_21k_224, faster_vit_4_21k_384, faster_vit_4_21k_512, faster_vit_4_21k_768

  • NVCLIP variants: ViT-H-14-SigLIP-CLIPA-224, ViT-L-14-SigLIP-CLIPA-336, ViT-L-14-SigLIP-CLIPA-224

  • NVDINOv2 variants: vit_large_patch14_dinov2_swiglu, vit_giant_patch14_reg4_dinov2_swiglu

  • CRADIO variants: c_radio_p1_vit_huge_patch16_mlpnorm, c_radio_p2_vit_huge_patch16_mlpnorm, c_radio_p3_vit_huge_patch16_mlpnorm, c_radio_v2_vit_base_patch16, c_radio_v2_vit_large_patch16, c_radio_v2_vit_huge_patch16
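Putting this together, a minimal model backbone configuration might look like the following sketch (the path is a placeholder; head settings such as in_channels are covered under Foundation Models below):

model:
  backbone:
    type: "fan_small_12_p4_hybrid"
    pretrained_backbone_path: /path/to/pretrained_backbone.pth
    freeze_backbone: False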

Foundation Models#

The following is a subset of the supported architectures and their pretraining datasets. Note that in_channels must be updated under the head config to match the chosen architecture:

  • NVCLIP Image Backbones:

| Arch | Pretrained Dataset | in_channels |
|------|--------------------|-------------|
| ViT-H-14-SigLIP-CLIPA-224 | NVIDIA-commercial dataset | 1024 |
| ViT-L-14-SigLIP-CLIPA-336 | NVIDIA-commercial dataset | 768 |
| ViT-L-14-SigLIP-CLIPA-224 | NVIDIA-commercial dataset | 768 |

  • NVDINOv2 Image Backbones:

| Arch | Pretrained Dataset | in_channels |
|------|--------------------|-------------|
| vit_large_patch14_dinov2_swiglu | NVIDIA-commercial dataset | 1024 |
| vit_giant_patch14_reg4_dinov2_swiglu | NVIDIA-commercial dataset | 1536 |

  • RADIO Image Backbones:

| Arch | Pretrained Dataset | in_channels |
|------|--------------------|-------------|
| c_radio_p1_vit_huge_patch16_mlpnorm | NVIDIA-commercial dataset | 3840 |
| c_radio_p2_vit_huge_patch16_mlpnorm | NVIDIA-commercial dataset | 5120 |
| c_radio_p3_vit_huge_patch16_mlpnorm | NVIDIA-commercial dataset | 3840 |
| c_radio_v2_vit_base_patch16 | NVIDIA-commercial dataset | 2304 |
| c_radio_v2_vit_large_patch16 | NVIDIA-commercial dataset | 3072 |
| c_radio_v2_vit_huge_patch16 | NVIDIA-commercial dataset | 3840 |
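For example, when switching to the vit_large_patch14_dinov2_swiglu backbone, set the head's in_channels to the value from the table above. A sketch:

model:
  backbone:
    type: "vit_large_patch14_dinov2_swiglu"
  head:
    in_channels: 1024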

loss#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| type | str | CrossEntropyLoss | Loss type. | CrossEntropyLoss |
| label_smooth_val | float | 0.0 | Label smoothing value. | |
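A sketch of the loss configuration with label smoothing enabled (the exact nesting within the spec follows the structure of your retrieved $TRAIN_SPECS):

loss:
  type: "CrossEntropyLoss"
  label_smooth_val: 0.1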

Dataset Input for Classification PyT#

Here is an example dataset specification for Classification PyT:

Note

For the FTMS client, these parameters are set in JSON format.

dataset:
  dataset: "CLDataset"
  root_dir: /dataset/imagenet2012
  batch_size: 128
  workers: 1
  num_classes: 1000
  img_size: 224
  augmentation:
    mixup_cutmix: True
    random_flip:
      vflip_probability: 0
      hflip_probability: 0.5
      enable: True
    random_aug:
      enable: True
    random_erase:
      enable: True
    random_rotate:
      rotate_probability: 0.5
      angle_list: [90, 180, 270]
      enable: False
    random_color:
      brightness: 0.4
      contrast: 0.4
      saturation: 0.4
      enable: False
    with_scale_random_crop:
      enable: False
    with_random_crop: True
    with_random_blur: False
  train_dataset:
    images_dir: /dataset/imagenet2012/train
  val_dataset:
    images_dir: /dataset/imagenet2012/val
  test_dataset:
    images_dir: /dataset/imagenet2012/test

The table below describes the configurable parameters in dataset.

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| root_dir | str | | Path to the folder that contains classes.txt. | |
| dataset | str | | The dataset class. | |
| num_classes | int | | The number of classes in the training data. | |
| img_size | int | | The input image size. | |
| batch_size | int | | The batch size. | |
| workers | int | | The number of dataloader workers. | |
| shuffle | bool | | Flag to shuffle the dataloader. | True, False |
| augmentation | dict config | | Augmentation config. | |
| train_dataset | dict config | | Configuration for the training dataset path. | |
| train_nolabel | dict config | | Configuration for unlabeled training data. | |
| val_dataset | dict config | | Configuration for the validation dataset path. | |
| test_dataset | dict config | | Configuration for the testing dataset path. | |

augmentation#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| random_flip | dict config | | RandomFlip augmentation config. | |
| random_rotate | dict config | | RandomRotation augmentation config. | |
| random_color | dict config | | RandomColor augmentation config. | |
| random_erase | dict config | | RandomErase augmentation config. | |
| random_aug | dict config | | RandomAug augmentation config. | |
| with_scale_random_crop | dict config | | RandomCropWithScale augmentation config. | |
| with_random_blur | bool | | Flag to enable with_random_blur. | True, False |
| with_random_crop | bool | | Flag to enable with_random_crop. | True, False |
| mean | List[float] | | Mean for the augmentation. | |
| std | List[float] | | Standard deviation for the augmentation. | |
| mixup_cutmix | bool | False | Flag to enable mixup and cutmix. Not recommended for binary classification. | True, False |
| mixup_alpha | float | 0.4 | Mixup alpha. | |

RandomFlip#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| vflip_probability | float | 0.5 | Vertical flip probability. | |
| hflip_probability | float | 0.5 | Horizontal flip probability. | |
| enable | bool | True | Flag to enable the augmentation. | True, False |

RandomRotation#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| rotate_probability | float | 0.5 | Random rotate probability. | |
| angle_list | List[float] | [90, 180, 270] | Random rotate angles. | |
| enable | bool | True | Flag to enable the augmentation. | True, False |

RandomColor#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| brightness | float | 0.3 | Random color brightness. | |
| contrast | float | 0.3 | Random color contrast. | |
| saturation | float | 0.3 | Random color saturation. | |
| hue | float | 0.3 | Random color hue. | |
| enable | bool | True | Flag to enable RandomColor. | True, False |
| color_probability | float | 0.5 | Random color probability. | |

RandomCropWithScale#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| scale_range | List[float] | [1, 1.2] | Random scale range. | |
| enable | bool | True | Flag to enable the augmentation. | True, False |

RandomErase#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| erase_probability | float | 0.2 | Random erase probability. | |
| enable | bool | True | Flag to enable the augmentation. | True, False |

RandomAug#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| enable | bool | True | Flag to enable the augmentation. | True, False |

train_dataset#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| images_dir | str | | Path to the images directory for the dataset. | |

val_dataset#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| images_dir | str | | Path to the images directory for the dataset. | |

test_dataset#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| images_dir | str | | Path to the images directory for the dataset. | |

train_nolabel#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| folder_path | Optional[str] | | Dataset directory path. | |

train#

The train config contains the parameters related to training. They are described as follows:

Note

For the FTMS client, these parameters are set in JSON format.
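A minimal train section might look like the following sketch (values are illustrative; defaults are listed in the table below):

train:
  num_epochs: 40
  checkpoint_interval: 1
  validation_interval: 1
  optim:
    optim: "adamw"
    lr: 0.00006
    policy: "cosine"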

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| optim | dict config | | Optimizer config. | |
| pretrained_model_path | str | None | Pretrained model path. | |
| tensorboard | dict config | | Configuration for the TensorBoard logger. | |
| enable_ema | bool | False | Flag to enable EMA. | True, False |
| ema_decay | float | 0.998 | EMA decay. | |
| clip_grad_norm | float | 2.0 | Gradient clipping norm. | |
| num_gpus | int | 1 | The number of GPUs to run the train job. | |
| gpu_ids | List[int] | [0] | List of GPU IDs to run the training on. | |
| num_nodes | int | 1 | Number of nodes to run the training on. | |
| seed | int | 1234 | The seed for the initializer in PyTorch. | |
| num_epochs | int | 10 | Number of epochs to run the training. | |
| checkpoint_interval | int | 1 | Checkpoint interval. | |
| validation_interval | int | 1 | Validation interval. | |
| resume_training_checkpoint_path | str | None | Path to the checkpoint to resume training from. | |
| results_dir | str | None | Path to where all the assets are stored. | |

optim#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| monitor_name | str | val_loss | The metric to monitor. | |
| optim | str | adamw | Optimizer. | adamw, adam, sgd |
| lr | float | 0.00006 | Optimizer learning rate. | |
| policy | str | linear | Learning-rate policy. | linear, step, cosine, multistep |
| policy_params | Dict[str, Any] | {"step_size": 30, "gamma": 0.1, "milestones": [10, 20]} | Parameters for the learning-rate policy. | |
| momentum | float | 0.9 | The momentum for the optimizer. | |
| weight_decay | float | 0.01 | The weight decay coefficient. | |
| betas | List[float] | [0.9, 0.999] | Coefficients used for computing running averages in AdamW. | |
| skip_names | List[str] | [] | Names of layers to exclude from weight decay. | |
| warmup_epochs | int | 0 | Warmup epochs. | |
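For example, a sketch switching to a multistep schedule using the policy_params keys shown above (values are illustrative):

optim:
  optim: "sgd"
  lr: 0.01
  momentum: 0.9
  policy: "multistep"
  policy_params:
    milestones: [10, 20]
    gamma: 0.1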

tensorboard#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| enabled | bool | False | Flag to enable TensorBoard. | True, False |
| infrequent_logging_frequency | int | 2 | The interval (in epochs) at which infrequent summaries are logged. | |
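A sketch enabling TensorBoard logging under train:

train:
  tensorboard:
    enabled: True
    infrequent_logging_frequency: 2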

evaluate#

Here is an example evaluate specification for Classification PyT:

Note

For the FTMS client, these parameters are set in JSON format.

evaluate:
  checkpoint: /path/to/model.pth

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| vis_after_n_batches | int | 1 | Visualize evaluation results after every n batches. | |
| batch_size | int | 8 | Batch size. | |
| checkpoint | str | | Path to the checkpoint used for evaluation. | |
| num_gpus | int | 1 | The number of GPUs to run the evaluate job. | |
| gpu_ids | List[int] | [0] | List of GPU IDs to run the evaluation on. | |
| num_nodes | int | 1 | Number of nodes to run the evaluation on. | |
| trt_engine | Optional[str] | None | Path to the TensorRT engine to be used for evaluation. | |
| results_dir | Optional[str] | None | Path to where all the assets are stored. | |

inference#

The inference config contains the parameters related to inference. They are described as follows:

Note

For the FTMS client, these parameters are set in JSON format.

inference:
  checkpoint: ${results_dir}/train/model_latest.pth

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| vis_after_n_batches | int | 1 | Visualize inference results after every n batches. | |
| batch_size | int | 8 | Batch size. | |
| checkpoint | str | | Path to the checkpoint used for inference. | |
| num_gpus | int | 1 | The number of GPUs to run the inference job. | |
| gpu_ids | List[int] | [0] | List of GPU IDs to run the inference on. | |
| num_nodes | int | 1 | Number of nodes to run the inference on. | |
| trt_engine | Optional[str] | None | Path to the TensorRT engine to be used for inference. | |
| results_dir | Optional[str] | None | Path to where all the assets are stored. | |

export#

The export config contains the parameters related to export. They are described as follows:

Note

For the FTMS client, these parameters are set in JSON format.

export:
  results_dir: "${results_dir}/export"
  gpu_id: 0
  checkpoint: ${results_dir}/train/model_latest.pth
  onnx_file: "${export.results_dir}/model_latest.onnx"
  input_width: 224
  input_height: 224
  batch_size: -1

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| results_dir | Optional[str] | None | Path to where all the assets are stored. | |
| gpu_id | int | 0 | The index of the GPU to run export on. | |
| checkpoint | str | | Path to the checkpoint file to run export. | |
| onnx_file | str | | Path to the ONNX model file. | |
| on_cpu | bool | False | Flag to export a CPU-compatible model. | True, False |
| input_channel | int | 3 | Number of channels in the input tensor. | 1, 3 |
| input_width | int | 960 | Width of the input image tensor. | |
| input_height | int | 544 | Height of the input image tensor. | |
| opset_version | int | 17 | ONNX operator set version. | |
| batch_size | int | -1 | The batch size of the input tensor for the engine. | |

distill#

The distill config contains the parameters related to distillation. They are described as follows:

Note

For the FTMS client, these parameters are set in JSON format.

distill:
  teacher:
    backbone:
      type: "vit_large_patch14_dinov2_swiglu"
      pretrained_backbone_path: <pretrained_model_path>
      freeze_backbone: True
  pretrained_teacher_model_path: <pretrained_teacher_path>

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| teacher | dict config | | Configuration hyperparameters for the teacher model. | |
| loss_type | str | KL | Loss function for logits distillation. | KL, CE, L1, L2 |
| loss_lambda | float | 0.5 | The weight applied to the distillation loss relative to the task loss. | |
| pretrained_teacher_model_path | str | | Path to the pretrained teacher model. | |
| results_dir | str | | Path to where all the assets generated from a task are stored. | |

teacher#

| Parameter | Datatype | Default | Description | Supported Values |
|-----------|----------|---------|-------------|------------------|
| backbone | dict config | | Configuration parameters for the backbone. | |
| head | dict config | | Configuration parameters for the head. | |
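Following the generic command pattern at the top of this page, a distill job can be launched as follows (a sketch; adjust the specs as needed for your teacher configuration):

DISTILL_SPECS=$(tao-client classification_pyt get-spec --action distill --job_type experiment --id $EXPERIMENT_ID)

DISTILL_JOB_ID=$(tao-client classification_pyt experiment-run-action --action distill --id $EXPERIMENT_ID --specs "$DISTILL_SPECS")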

Training the model#

Use the tao model classification_pyt train command to train a classification PyTorch model:

TRAIN_JOB_ID=$(tao-client classification_pyt experiment-run-action --action train --id $EXPERIMENT_ID --specs "$TRAIN_SPECS")

Evaluating the Model#

After the model has been trained using the experiment spec file and the steps above, the next step is to evaluate the model on a test set to measure its accuracy. TAO provides the tao model classification_pyt evaluate command for this purpose.

The classification app computes evaluation loss and Top-k accuracy.

After training, the model is stored in your FTMS experiment's cloud workspace. When using the TAO Launcher, it is stored in the results_dir output directory of your choice.
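Retrieve the evaluate specifications first, following the command pattern above:

EVAL_SPECS=$(tao-client classification_pyt get-spec --action evaluate --job_type experiment --id $EXPERIMENT_ID)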

EVAL_JOB_ID=$(tao-client classification_pyt experiment-run-action --action evaluate --id $EXPERIMENT_ID --specs "$EVAL_SPECS" --previous_job_id=$TRAIN_JOB_ID)

Running Inference on a Model#

For classification, tao model classification_pyt inference saves a .csv file containing the image paths and the corresponding labels for multiple images. TensorRT Python inference can also be enabled.
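Retrieve the inference specifications first, following the command pattern above:

INFER_SPECS=$(tao-client classification_pyt get-spec --action inference --job_type experiment --id $EXPERIMENT_ID)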

INFER_JOB_ID=$(tao-client classification_pyt experiment-run-action --action inference --id $EXPERIMENT_ID --specs "$INFER_SPECS" --previous_job_id=$TRAIN_JOB_ID)

Exporting the model#

Exporting the model decouples the training process from inference and allows conversion to TensorRT engines outside the TAO environment. TensorRT engines are specific to each hardware configuration and should be generated for each unique inference environment. The exported model, in .onnx format, may be used universally across training and deployment hardware.
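Retrieve the export specifications first, following the command pattern above:

EXPORT_SPECS=$(tao-client classification_pyt get-spec --action export --job_type experiment --id $EXPERIMENT_ID)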

EXPORT_JOB_ID=$(tao-client classification_pyt experiment-run-action --action export --id $EXPERIMENT_ID --specs "$EXPORT_SPECS" --previous_job_id=$TRAIN_JOB_ID)

TensorRT Engine Generation, Validation, and INT8 Calibration#

For TensorRT engine generation, validation, and INT8 calibration, refer to the TAO Deploy documentation.

Deploying to DeepStream#

Refer to the Integrating a Classification (TF1/TF2/PyTorch) Model page for more information about deploying a classification model with DeepStream.