3. Video Classification and ASR with HuggingFace Accelerate on DGX Cloud

3.1. Overview

This guide demonstrates how to enable distributed training of two different models with HuggingFace Accelerate. The first example fine-tunes a video classification model with data processing and augmentation based on PyTorchVideo. The second is quantized LoRA (QLoRA) fine-tuning of whisper-large-v2, a high-performance automatic speech recognition (ASR) model. Together, these examples demonstrate training throughput scalability on a Slurm cluster and the importance of data pre-processing for efficient training.

3.2. Prerequisites

Please check that the following prerequisites are met before using this guide.

  1. A DGX Cloud Slurm cluster is provisioned, and the user can launch and run jobs on the cluster (root permissions are not necessary).

  2. The user has access to at least two A100 or H100-based compute nodes on the cluster that can run jobs with high-speed multi-node compute networks.

  3. The user has read/write access to at least 100 GB of shared storage that is mounted on all nodes in the cluster and available within jobs (a quick check is sketched after this list). To identify the shared storage path, please consult your cluster administrator. The path is represented by <SHARED_STORAGE_ROOT> in this document; an example might be /lustre/fs0/scratch/demo-user.

  4. The cluster has external internet access on all nodes to download datasets and pre-trained models.

  5. The user has a valid Kaggle account and API token. This is required for the ASR example. To obtain a Kaggle API token, please refer to the Kaggle documentation.

  6. The cluster login node has a Python environment available that allows users to install their own packages locally.
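
For prerequisite 3, a quick way to confirm shared storage access is to check free space and write permission from a login node. The snippet below is an optional sketch; the path shown is the example value above and should be replaced with your own <SHARED_STORAGE_ROOT>.

# Optional sketch: verify <SHARED_STORAGE_ROOT> is writable and has at least 100 GB free
import os
import shutil

shared_root = "/lustre/fs0/scratch/demo-user"  # replace with your <SHARED_STORAGE_ROOT>

free_gb = shutil.disk_usage(shared_root).free / 1e9
test_file = os.path.join(shared_root, ".write_test")
with open(test_file, "w") as f:   # raises an error if the path is not writable
    f.write("ok")
os.remove(test_file)

print("Free space: %.0f GB (at least 100 GB required)" % free_gb)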

3.2.1. Preparing a Customized Container Image

The Slurm implementation in your DGX Cloud cluster leverages Pyxis and Enroot as the container plugin and runtime. This section explains how to build a customized environment into a container image for use on BCM-based clusters with Slurm and Pyxis/Enroot. The build process should be completed in a local Linux environment. By building a containerized environment, we obtain a frozen environment for model training without concerns about package compatibility.

3.2.1.1. Creating a Container Image in a Local Linux Machine

To build a customized container environment for this workload, we need to install the Docker engine, an open-source containerization technology for building and running applications, on the local machine. Once Docker is installed, we can pull container images and build a container environment that supports the model training by providing a list of instructions for assembling an image. First, we create a Dockerfile with the following content:

# Dockerfile
FROM nvcr.io/nvidia/pytorch:24.01-py3

RUN pip install lightning==2.2.1 \
    transformers==4.39.3 \
    evaluate==0.4.1 \
    accelerate==0.29.2 \
    jiwer==3.0.3 \
    bitsandbytes==0.43.1 \
    peft==0.10.0 \
    librosa==0.10.1 \
    datasets==2.18.0 \
    opendatasets==0.1.22 \
    gradio==4.26.0

RUN cd /opt && \
      git clone https://github.com/facebookresearch/pytorchvideo.git && \
      cd /opt/pytorchvideo && \
      git checkout 1fadaef40dd393ca09680f55582399f4679fc9b7 && \
      pip install -e .

In this Dockerfile, we select a PyTorch container image from NGC as the base image (nvcr.io/nvidia/pytorch:24.01-py3). We use pip to install the HuggingFace Transformers library and other dependencies. Next, we clone the PyTorchVideo repository from GitHub and check out a fixed commit. Execute the following command in the same folder as the Dockerfile, and a container image pytorch-vidasr:24.01 will be created. Note that the trailing “dot” in the command is necessary.

docker build -t pytorch-vidasr:24.01 .

Once the container build is complete, we can list the image with the following command; the output should look similar to this.

docker images | grep vidasr
REPOSITORY      TAG    IMAGE ID       CREATED          SIZE
pytorch-vidasr  24.01  45909445a2ee   11 minutes ago   22.7GB

3.2.1.2. Setting Up Your Cluster Workspace

Before model training, we must prepare the scripts, dataset, and the containerized environment detailed in the previous step. Please refer to the Cluster User Guide section Setting Up NGC Integration for details on container and Enroot setup.

From here, we have two options to use this image in the Slurm cluster.

3.2.1.2.1. Option 1: Push to an Accessible Container Registry

A container registry is a repository, or collection of repositories, used to store container images; the NGC Catalog is one example. If we have a container registry (denoted as <YOUR_REGISTRY>) that is accessible from the cluster and allows image uploads, we can push the image to it so that Pyxis/Enroot in the Slurm cluster can pull it. If the registry requires login and password authentication, execute docker login <YOUR_REGISTRY> and log in with your credentials first. Next, tag the image and push it to your registry. Note that you may need to tag the container image with a different name or additional repository path. For brevity, we use <YOUR_REGISTRY>/pytorch-vidasr:24.01 from now on.

docker tag pytorch-vidasr:24.01 <YOUR_REGISTRY>/pytorch-vidasr:24.01
docker push <YOUR_REGISTRY>/pytorch-vidasr:24.01

3.2.1.2.2. Option 2: Convert to a SquashFS File and Upload to the Slurm Cluster

If uploading our container image to a container registry is not an option, we will need to convert the image into a SquashFS file using NVIDIA Enroot, a tool that turns traditional container/OS images into unprivileged sandboxes. First, check the Enroot version used in the cluster: log in to a cluster login node and execute enroot version (3.4.1 as of March 11, 2024).

Next, follow the instructions here to install the Enroot version on your local machine that matches the version in the cluster (to ensure compatibility). Once installation is complete, we can convert the pytorch-vidasr:24.01 image to a SquashFS file pytorch-vidasr-24.01.sqsh on our local machine with the following command.

enroot import -o pytorch-vidasr-24.01.sqsh dockerd://pytorch-vidasr:24.01

Once the conversion is complete, we can upload the SquashFS file to the Slurm cluster using one of the methods described in the Cluster User Guide section Moving Data from Your Local Workstation to DGX Cloud. Note that the final destination of the SquashFS file must be on a shared file system so that all compute nodes can access it when the distributed workload is launched. For example, if scp is used to upload the SquashFS file, we can execute the following command on the local machine:

scp pytorch-vidasr-24.01.sqsh \
      <USERNAME>@<LOGIN_NODE>:<SHARED_STORAGE_ROOT>/

where <USERNAME> is the user name in the cluster and <LOGIN_NODE> is the node address of a login node in the DGX Cloud cluster.

3.3. Running on Slurm

This section covers cluster workspace preparation, Slurm batch script configuration, and checking multi-node training functionality.

3.3.1. Enabling Slurm Commands

Both of the use cases require Slurm. If Slurm commands are not enabled yet, execute the following command.

module load slurm

3.3.2. Use Case 1: Fine-Tuning a Video Classification Model with Slurm

Use SSH to access the cluster login node, where we will use the shell to execute various steps.

3.3.2.1. Workspace and Video Dataset Preparation

We first create directories in the shared scratch space as our workspace with the following command.

mkdir -p <SHARED_STORAGE_ROOT>/videocls/hf_workspace

Next, we use a subset of the UCF101 dataset for a basic test, which can be downloaded from the HuggingFace dataset repository. We first install the huggingface_hub Python package on the login node so that we can use it to download the dataset later.

module load python3
pip install huggingface_hub

Now we can create a data preparation file with the contents below and save it to our workspace directory.

# <SHARED_STORAGE_ROOT>/videocls/data_prep.py
from huggingface_hub import hf_hub_download
import os
import pathlib

hf_dataset_identifier = "sayakpaul/ucf101-subset"
filename = "UCF101_subset.tar.gz"
file_path = hf_hub_download(repo_id=hf_dataset_identifier,
                filename=filename,
                repo_type="dataset")
os.system("tar xf %s" % file_path)

Next, we execute the script in our training folder.

cd <SHARED_STORAGE_ROOT>/videocls/hf_workspace
python ../data_prep.py

3.3.2.2. Training Script

ViViT is a Transformer-based model for video classification from Google. It extracts spatio-temporal tokens from the input video and handles long sequences by factorizing the spatial and temporal dimensions of the input, which makes it especially compelling for showcasing data-distribution scaling, since it can effectively handle very long video sequences. The training script performs supervised fine-tuning of a pre-trained ViViT base model on a subset of the UCF101 dataset containing ten unique classes with 30 videos each. The key steps in the script are:

  1. Preprocess and augment (scale, crop, flip, resize, subsample, etc.) the videos using the PyTorchVideo library.

  2. Train the model using data-parallel distribution.

Save the following code to a script at <SHARED_STORAGE_ROOT>/videocls/hf_workspace/train_vivit.py.

# <SHARED_STORAGE_ROOT>/videocls/hf_workspace/train_vivit.py
import pathlib
import pytorchvideo.data
from pytorchvideo.transforms import (
      ApplyTransformToKey,
      Normalize,
      RandomShortSideScale,
      RemoveKey,
      ShortSideScale,
      UniformTemporalSubsample,
)

from torchvision.transforms import (
      Compose,
      Lambda,
      RandomCrop,
      RandomHorizontalFlip,
      Resize,
)

from transformers import VivitImageProcessor, VivitForVideoClassification
from transformers import TrainingArguments, Trainer
import evaluate
import torch
from torch.utils.data import SequentialSampler
import os
import numpy as np

from accelerate import Accelerator
from accelerate.data_loader import IterableDatasetShard

def preprocess_dataset(dataset_root_path, image_processor, num_frames_to_sample):
      mean = image_processor.image_mean
      std = image_processor.image_std
      if "shortest_edge" in image_processor.size:
            height = width = image_processor.size["shortest_edge"]
      else:
            height = image_processor.size["height"]
            width = image_processor.size["width"]
      resize_to = (height, width)

      sample_rate = 1
      fps = 30
      clip_duration = num_frames_to_sample * sample_rate / fps

      # Training dataset transformations
      train_transform = Compose(
                  [
                  ApplyTransformToKey(
                        key="video",
                        transform=Compose(
                              [
                              UniformTemporalSubsample(num_frames_to_sample),
                              Lambda(lambda x: x / 255.0),
                              Normalize(mean, std),
                              RandomShortSideScale(min_size=256, max_size=320),
                              RandomCrop(resize_to),
                              RandomHorizontalFlip(p=0.5),
                              ]
                              ),
                        ),
                  ]
                  )

      # Training dataset
      train_dataset = pytorchvideo.data.Ucf101(
                  data_path=os.path.join(dataset_root_path, "train"),
                  clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
                  decode_audio=False,
                  transform=train_transform,
                  )

      # Validation and evaluation datasets' transformations
      val_transform = Compose(
            [
                  ApplyTransformToKey(
                  key="video",
                  transform=Compose(
                        [
                              UniformTemporalSubsample(num_frames_to_sample),
                              Lambda(lambda x: x / 255.0),
                              Normalize(mean, std),
                              Resize(resize_to),
                              ]
                        ),
                  ),
                  ]
            )

      # Validation and evaluation datasets
      val_dataset = pytorchvideo.data.Ucf101(
                  data_path=os.path.join(dataset_root_path, "val"),
                  clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
                  decode_audio=False,
                  transform=val_transform,
                  )

      test_dataset = pytorchvideo.data.Ucf101(
                  data_path=os.path.join(dataset_root_path, "test"),
                  clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
                  decode_audio=False,
                  transform=val_transform,
                  )

      return train_dataset, val_dataset, test_dataset

accelerator = Accelerator()
print("Process ID: %d of %d" % (accelerator.process_index, accelerator.num_processes))
print("Available GPU devices: %d" % torch.cuda.device_count())

dataset_root_path = "UCF101_subset"
model_ckpt = "google/vivit-b-16x2-kinetics400" # pre-trained model from which to fine-tune
batch_size = 4 # Per-device batch size for training and evaluation

image_processor = VivitImageProcessor.from_pretrained(model_ckpt)

dataset_root_path = pathlib.Path(dataset_root_path)
video_count_train = len(list(dataset_root_path.glob("train/*/*.avi")))
video_count_val = len(list(dataset_root_path.glob("val/*/*.avi")))
video_count_test = len(list(dataset_root_path.glob("test/*/*.avi")))
video_total = video_count_train + video_count_val + video_count_test
print(f"Total videos: {video_total}")

all_video_file_paths = (
list(dataset_root_path.glob("train/*/*.avi"))
      + list(dataset_root_path.glob("val/*/*.avi"))
      + list(dataset_root_path.glob("test/*/*.avi"))
)

class_labels = sorted({str(path).split("/")[2] for path in all_video_file_paths})
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

print(f"Unique classes: {list(label2id.keys())}.")
model = VivitForVideoClassification.from_pretrained(
      model_ckpt,
      label2id=label2id,
      id2label=id2label,
      ignore_mismatched_sizes=True,  # provide this in order to fine-tune an already fine-tuned checkpoint
)
train_dataset, val_dataset, test_dataset = (
    preprocess_dataset(dataset_root_path, image_processor, model.config.num_frames)
    )

# Training setup

model_name = model_ckpt.split("/")[-1]
new_model_name = ("%s-finetuned-ucf101-subset-%s-n-%s-g-%d-b" %
                  (model_name, os.getenv("SLURM_NNODES"), os.getenv("SLURM_GPUS_PER_NODE"), batch_size))

args = TrainingArguments(
      new_model_name,
      remove_unused_columns=False,
      evaluation_strategy="epoch",
      save_strategy="epoch",
      save_on_each_node=False,
      learning_rate=5e-5,
      per_device_train_batch_size=batch_size,
      per_device_eval_batch_size=batch_size,
      warmup_ratio=0.1,
      logging_steps=10,
      load_best_model_at_end=True,
      metric_for_best_model="accuracy",
      push_to_hub=False,
      dataloader_num_workers=15, # Set it to 1 for single preprocess worker
      dataloader_prefetch_factor=64,
      max_steps=(train_dataset.num_videos // batch_size)*2,
)

# Next, we need to define a function for how to compute the metrics from the predictions,
# which will just use the metric we'll load now. The only preprocessing we have to do
# is to take the argmax of our predicted logits:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
      """Computes accuracy on a batch of predictions."""
      predictions = np.argmax(eval_pred.predictions, axis=1)
      return metric.compute(predictions=predictions, references=eval_pred.label_ids)

def collate_fn(examples):
      """The collation function to be used by `Trainer` to prepare data batches."""
      # permute to (num_frames, num_channels, height, width)
      pixel_values = torch.stack(
            [example["video"].permute(1, 0, 2, 3) for example in examples]
      )
      labels = torch.tensor([example["label"] for example in examples])
      return {"pixel_values": pixel_values, "labels": labels}

trainer = Trainer(
      model,
      args,
      train_dataset=train_dataset,
      eval_dataset=val_dataset,
      tokenizer=image_processor,
      compute_metrics=compute_metrics,
      data_collator=collate_fn,
)

train_results = trainer.train()

trainer.save_model()
test_results = trainer.evaluate(test_dataset)
trainer.log_metrics("test", test_results)
trainer.save_metrics("test", test_results)
trainer.save_state()

Important variables are noted below.

  1. dataloader_num_workers is set to 15 in our training arguments. Because the dataset applies preprocessing and augmentation from the PyTorchVideo library, more workers are needed to keep GPU utilization high. We also ran a comparison test with this value set to 1.

  2. We set max_steps to a fixed value of (train_dataset.num_videos // batch_size) * 2, which is 150 in our case. As a result, the number of completed training epochs scales with the number of GPUs.

Note that the purpose of the provided script is to validate the data-parallel training function. To optimize for other datasets, developers can tune the training arguments and other parameters in the script as necessary.

3.3.2.3. Batch Submission Script

Now we prepare the batch script with the content below and save it in our workspace folder (<SHARED_STORAGE_ROOT>/videocls/train-vivit-hf.sh).

#!/bin/bash

##SBATCH --job-name
##SBATCH --nodes
##SBATCH --gpus-per-node
#SBATCH --account=<SLURM_ACCOUNT>
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err
#SBATCH --partition=<SLURM_PARTITION>
#SBATCH --time=01:00:00
#SBATCH --exclusive
#SBATCH --ntasks-per-node=1

# Environment variables added for DGX Cloud
export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=tcp
export UCX_NET_DEVICES=eth0
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_PCI_RELAXED_ORDERING=1
export NCCL_TOPO_FILE=/cm/shared/etc/ndv4-topo.xml
export NCCL_DEBUG=INFO
export NCCL_PROTO=LL,LL128,Simple
export NCCL_ALGO=Tree,Ring,CollnetDirect,CollnetChain,NVLS
export MELLANOX_VISIBLE_DEVICES=all
export PMIX_MCA_gds=hash
export PMIX_MCA_psec=native

export SHARED_STORAGE_ROOT=<SHARED_STORAGE_ROOT>
export CONTAINER_WORKSPACE_MOUNT=$SHARED_STORAGE_ROOT/videocls/hf_workspace
export CONTAINER_IMAGE=$SHARED_STORAGE_ROOT/pytorch-vidasr-24.01.sqsh

export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=$(( RANDOM % (50000 - 30000 + 1 ) + 30000 ))
export GPUS_PER_NODE=$SLURM_GPUS_PER_NODE
export NNODES=$SLURM_NNODES
export NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)
export MULTIGPU_FLAG="--multi_gpu"

if [ $NNODES == "1" ]
then
        export MULTIGPU_FLAG=""
fi

echo "MASTER_ADDR: $MASTER_ADDR"
echo "MASTER_PORT: $MASTER_PORT"

srun -l --container-image $CONTAINER_IMAGE \
        --container-mounts $CONTAINER_WORKSPACE_MOUNT:/workspace \
        --container-workdir /workspace \
        --no-container-mount-home \
        bash -c 'accelerate launch  --main_process_ip ${MASTER_ADDR} \
                                    --main_process_port ${MASTER_PORT} \
                                    --machine_rank $SLURM_NODEID \
                                    $MULTIGPU_FLAG \
                                    --same_network \
                                    --num_processes $NUM_PROCESSES \
                                    --num_cpu_threads_per_process 4 \
                                    --num_machines $NNODES train_vivit.py'

Several variables in this batch script must be configured for your Slurm account, Slurm partition, container image, and resource preferences.

  • <SLURM_ACCOUNT>: The Slurm account to be used for your project. Please consult your cluster administrator or project manager to determine which account name to use.

  • <SHARED_STORAGE_ROOT>: The root path of user shared storage defined in earlier sections.

  • <SLURM_PARTITION>: The Slurm partition(s) to use for this job. Slurm partitions are defined by cluster administrators and designated for different purposes or accounts. Note that a partition with multi-node job and GPU support must be selected.

  • <CONTAINER_IMAGE>: The container image name or path to be used in Slurm with Pyxis/Enroot. This depends on which option was used in the container image build section above.

    • If Option 1 is used, replace it with <YOUR_REGISTRY>/pytorch-vidasr:24.01

    • If Option 2 is used, replace it with the upload destination <SHARED_STORAGE_ROOT>/pytorch-vidasr-24.01.sqsh

Several variables, such as job-name, nodes, and gpus-per-node, are not set directly in this batch script; instead, we assign them in the submission command to exercise different resource configurations. In this example, we start from one node with 1 GPU, scale to 2, 4, and 8 GPUs, and then to two 8-GPU nodes. The submission commands are listed in the table below.

sbatch run configurations

Number of nodes | GPUs per node | Job submission command (in the script folder <SHARED_STORAGE_ROOT>/videocls)
1 | 1 | sbatch --job-name=vivit-train-acc-n1g1b4 --nodes=1 --gpus-per-node=1 train-vivit-hf.sh
1 | 2 | sbatch --job-name=vivit-train-acc-n1g2b4 --nodes=1 --gpus-per-node=2 train-vivit-hf.sh
1 | 4 | sbatch --job-name=vivit-train-acc-n1g4b4 --nodes=1 --gpus-per-node=4 train-vivit-hf.sh
1 | 8 | sbatch --job-name=vivit-train-acc-n1g8b4 --nodes=1 --gpus-per-node=8 train-vivit-hf.sh
2 | 8 | sbatch --job-name=vivit-train-acc-n2g8b4 --nodes=2 --gpus-per-node=8 train-vivit-hf.sh

3.3.2.4. Training Steps, Epochs, and Time

We can retrieve the training epoch information and elapsed training time of each job from the trainer_state.json file in the model save subfolder listed in the table below. The epoch number can be obtained by looking at the final epoch value in the log_history section. Note that we only check the integer part of this value.

# trainer_state.json snippet for 1-Node, 1-GPU
{
  .....
  "log_history": [
    {
      "epoch": 0.07,
      "grad_norm": 14.361384391784668,
      "learning_rate": 3.3333333333333335e-05,
      "loss": 2.4146,
      "step": 10
    },
    {
      "epoch": 0.13,
      "grad_norm": 11.673120498657227,
      "learning_rate": 4.814814814814815e-05,
      "loss": 1.6931,
      "step": 20
    },
    {
      "epoch": 0.2,
      "grad_norm": 8.026626586914062,
      "learning_rate": 4.4444444444444447e-05,
      "loss": 1.106,
      "step": 30
    },
    .....
    {
      "epoch": 1.43,
      "grad_norm": 0.2679580748081207,
      "learning_rate": 3.7037037037037037e-06,
      "loss": 0.0442,
      "step": 140
    },
    {
      "epoch": 1.5,
      "grad_norm": 0.6468069553375244,
      "learning_rate": 0.0,
      "loss": 0.0889,
      "step": 150
    },
    {
      "epoch": 1.5,
      "eval_accuracy": 1.0,
      "eval_loss": 0.0483052060008049,
      "eval_runtime": 14.0177,
      "eval_samples_per_second": 10.629,
      "eval_steps_per_second": 2.711,
      "step": 150
    },
    {
      "epoch": 1.5,
      "step": 150,
      "total_flos": 1.537335139321774e+18,
      "train_loss": 0.5025620261828104,
      "train_runtime": 150.2643,
      "train_samples_per_second": 3.993,
      "train_steps_per_second": 0.998
    },
    {
      "epoch": 1.5,
      "eval_accuracy": 1.0,
      "eval_loss": 0.04270438104867935,
      "eval_runtime": 32.402,
      "eval_samples_per_second": 10.709,
      "eval_steps_per_second": 2.685,
      "step": 150
    }
  ],
  .....

To find the training time, look at the final entries in log_history; the value is recorded in train_runtime in seconds. We also compared different numbers of data loader workers. With only one data loader process, data processing becomes the major bottleneck for training throughput, even with single-GPU training. Using 15 processes yields significant performance gains for up to 4-GPU training.
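
The following helper is a minimal sketch (not part of the provided training code) for pulling both values out of a run's trainer_state.json; the script name and argument handling are illustrative only.

# summarize_run.py (illustrative) -- print the final epoch and train_runtime of a run
# Usage: python summarize_run.py <model_save_subfolder>/trainer_state.json
import json
import sys

with open(sys.argv[1]) as f:
    state = json.load(f)

log_history = state["log_history"]
last_epoch = max(entry["epoch"] for entry in log_history)
train_runtime = next(entry["train_runtime"]
                     for entry in reversed(log_history) if "train_runtime" in entry)

print("Integer part of last logged epoch:", int(last_epoch))
print("train_runtime (seconds):", train_runtime)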

Run configuration and training epochs

Number of nodes | GPUs per node | Model save subfolder | Global batch size (\(B_g\)) | Integer part of the last logged epoch (\(\lceil 150/\lceil 300/B_g \rceil \rceil - 1\))
1 | 1 | vivit-b-16x2-kinetics400-finetuned-ucf101-subset-1-n-1-g-4-b/ | 4 | 1
1 | 2 | vivit-b-16x2-kinetics400-finetuned-ucf101-subset-1-n-2-g-4-b/ | 8 | 3
1 | 4 | vivit-b-16x2-kinetics400-finetuned-ucf101-subset-1-n-4-g-4-b/ | 16 | 7
1 | 8 | vivit-b-16x2-kinetics400-finetuned-ucf101-subset-1-n-8-g-4-b/ | 32 | 14
2 | 8 | vivit-b-16x2-kinetics400-finetuned-ucf101-subset-2-n-8-g-4-b/ | 64 | 29
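
The last column above can be reproduced from the fixed max_steps (150) and the global batch size; the short sketch below is only a sanity check of that arithmetic.

# Sanity check of the epoch column: 300 training videos, per-device batch size 4, max_steps = 150
import math

num_train_videos = 300   # 10 classes x 30 training videos each
max_steps = 150          # (train_dataset.num_videos // batch_size) * 2 with batch_size = 4

for num_gpus in (1, 2, 4, 8, 16):
    global_batch = 4 * num_gpus
    steps_per_epoch = math.ceil(num_train_videos / global_batch)
    last_epoch_int = math.ceil(max_steps / steps_per_epoch) - 1
    print(num_gpus, global_batch, last_epoch_int)   # prints 1, 3, 7, 14, 29 for the rows above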

Note that the UCF101 subset dataset object is a derivative of IterableDataset, in which case HuggingFace Accelerate enables the dispatch_batches mechanism by default for multi-GPU training. In other words, the designated data loader workers handle all data processing and augmentation, then dispatch the processed batches to all GPUs. For practical use cases with larger datasets and more GPU resources, further chunking of the dataset is recommended.

Data loaders and training runtime

Number of nodes | GPUs per node | Training runtime with 15 data loader processes (seconds) | Training runtime with 1 data loader process (seconds)
1 | 1 | 110.8 | 150.3
1 | 2 | 128.0 | 300.3
1 | 4 | 172.7 | 619.1
1 | 8 | 324.2 | 1154.0
2 | 8 | 623.1 | 2374.5

NOTE: Timings shown are for reference only.

3.3.3. Use Case 2: QLoRA Fine-Tuning of ASR with Slurm

3.3.3.1. Workspace and ASR Dataset Preparation

We create a new folder in the shared scratch space as our workspace for the ASR example with the following command.

mkdir -p <SHARED_STORAGE_ROOT>/asr

Next, we install the opendatasets Python package on the login node so that we can use it to download a dataset later.

module load python3
pip install opendatasets

Now we can download a dataset from Kaggle. A small dataset bengali-ai-asr-10k is used in our example for a quick test.

cd <SHARED_STORAGE_ROOT>/asr
python -c "import opendatasets as od;\
      od.download(\"https://www.kaggle.com/datasets/nbroad/bengali-ai-asr-10k\")"
# Depending on your local Kaggle API setup, a prompt appears for Kaggle user name and key
# For example
# Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
# Your Kaggle username:
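
If you prefer not to type credentials at the prompt, opendatasets can also read them from a kaggle.json file in the directory you run the download from. The snippet below is a sketch that writes such a file; the username and key values are placeholders for your own Kaggle credentials.

# Optional sketch: create kaggle.json in <SHARED_STORAGE_ROOT>/asr so no interactive prompt is needed
# Both values are placeholders; keep this file private, since it contains your API key
import json

credentials = {
    "username": "<YOUR_KAGGLE_USERNAME>",
    "key": "<YOUR_KAGGLE_API_KEY>",
}

with open("kaggle.json", "w") as f:
    json.dump(credentials, f)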

3.3.3.2. Training Script

The training script uses HuggingFace PEFT to tune Whisper on the Kaggle Bengali ASR dataset (1 GB). We load the model in 8-bit and add a LoRA adapter. Only the LoRA weights are retained, and we train on part of the training dataset. LoRA tuning keeps the original weights frozen and adapts the model by adding a low-rank update to those weights. We use a rank of 16 for this script.
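
To make the low-rank update concrete, the sketch below shows the adapter applied to a single frozen weight matrix, following the standard LoRA formulation; the matrix dimensions are illustrative, and this is not code from the training script.

# Illustrative LoRA update for one frozen weight matrix W (d_out x d_in), rank r = 16
import torch

d_out, d_in, r, alpha = 1280, 1280, 16, 32   # r and lora_alpha match the LoraConfig used below
W = torch.randn(d_out, d_in)                 # frozen pre-trained weight, never updated
A = torch.randn(r, d_in)                     # trainable low-rank factor
B = torch.zeros(d_out, r)                    # trainable; starts at zero so the update is initially a no-op
W_effective = W + (alpha / r) * (B @ A)      # effective weight used in the forward pass

print(W.numel(), r * (d_in + d_out))         # 1638400 frozen vs. 40960 trainable parameters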

Whisper is a pre-trained model for ASR from OpenAI; its architecture is a seq2seq model with an audio encoder and a text decoder. The feature extractor turns the 1D audio signal into a log-mel spectrogram, while the encoder produces hidden states that are passed to the decoder to generate text. Unlike predecessor ASR models such as Wav2Vec2.0, which was pre-trained on unlabeled data, Whisper was pre-trained on a vast quantity of labeled audio transcription data. Bengali is a good use case because, based on the Whisper paper, the model was trained on relatively little Bengali data. The key preprocessing step is the custom data collator, which dynamically pads all audio samples to an identical input length of 30 seconds. The script can easily be modified to support larger ASR datasets to showcase further scaling of distributed data training (for example, the Bengali ASR 80 GB dataset or the 30 GB librispeech-clean dataset).
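
The 30-second length comes from Whisper's feature extraction convention: every input is padded or truncated to 30 seconds of log-mel features. The standalone sketch below illustrates this with the HuggingFace WhisperFeatureExtractor on a short dummy clip; it is not part of the training script, which works on features already extracted into the parquet files.

# Illustration: Whisper pads/truncates audio to 30 s, giving a fixed (80 mel bins, 3000 frames) input
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v2")
five_second_clip = np.zeros(5 * 16000, dtype=np.float32)   # 5 s of silence at 16 kHz

features = feature_extractor(five_second_clip, sampling_rate=16000, return_tensors="np")
print(features.input_features.shape)                       # (1, 80, 3000)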

Save the following script to the path <SHARED_STORAGE_ROOT>/asr/qlora-asr.py. Note that, for the purpose of a quick test, the job only runs one training epoch on a single parquet file so that we can observe multi-GPU scalability.

# <SHARED_STORAGE_ROOT>/asr/qlora-asr.py
from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer
from transformers import WhisperProcessor
from transformers import WhisperForConditionalGeneration, BitsAndBytesConfig
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from peft import prepare_model_for_kbit_training
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from peft import LoraConfig, PeftModel, LoraModel, get_peft_model
import datasets
from datasets import DatasetDict, load_dataset
from pathlib import Path
import opendatasets as od
import os
import pandas
import evaluate

from accelerate import Accelerator, DistributedDataParallelKwargs

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor ([`WhisperProcessor`])
            The processor used for processing the data.
        decoder_start_token_id (`int`)
            The begin-of-sentence of the decoder.
        forward_attention_mask (`bool`)
            Whether to return attention_mask.
    """

    processor: Any

    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        model_input_name = self.processor.model_input_names[0]
        input_features = [
            {model_input_name: feature[model_input_name]} for feature in features
        ]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt"
        )

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if the bos token was appended in a previous tokenization step,
        # cut it here since it is appended later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

class SavePeftModelCallback(TrainerCallback):
    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)
        return control

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
device_index = Accelerator(kwargs_handlers=[ddp_kwargs]).local_process_index

device_map = {"": device_index}

model_name_or_path = "openai/whisper-large-v2"
task = "transcribe"
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language='bn', task=task)
processor = WhisperProcessor.from_pretrained(model_name_or_path, language='bn', task=task)

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
)

metric = evaluate.load("wer")
model = (WhisperForConditionalGeneration.from_pretrained(model_name_or_path,
         quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map=device_map))
print(model.hf_device_map)
model = prepare_model_for_kbit_training(model)
model.model.encoder.conv1.register_forward_hook(make_inputs_require_grad)

lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

vectorized_datasets = DatasetDict()
train_data_dir = os.getenv("ASR_DATASETS")
validation_data_dir = os.getenv("ASR_DATASETS")

train_files = list(map(str, Path(train_data_dir).glob("train*.parquet")))
vectorized_datasets["train"] = load_dataset("parquet", data_files=train_files[:1], split="train")

eval_files = list(map(str, Path(validation_data_dir).glob("eval*.parquet")))
vectorized_datasets["eval"] = load_dataset(
    "parquet", data_files=eval_files, split="train"
)

training_args = Seq2SeqTrainingArguments(
    # change to a repo name of your choice
    output_dir="lora/%s-%s-n-%s-g" % (train_data_dir, os.getenv("SLURM_NNODES"), os.getenv("SLURM_GPUS_PER_NODE")),
    report_to="none", ### comment this out to login to wandb
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    ddp_find_unused_parameters=False,
    learning_rate=1e-5,
    warmup_steps=50,
    num_train_epochs=1,
    evaluation_strategy="steps",
    fp16=True,
    gradient_checkpointing_kwargs={'use_reentrant':False},
    per_device_eval_batch_size=8,
    logging_steps=250,
    # required as the PeftModel forward doesn't have the signature of the wrapped model's forward
    remove_unused_columns=False,
    label_names=["labels"],  # same reason as above
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=vectorized_datasets["train"],
    eval_dataset=vectorized_datasets["eval"],
    data_collator=data_collator,
    tokenizer=processor.feature_extractor,
    callbacks=[SavePeftModelCallback],

)
model.config.use_cache = False

trainer.train()
trainer.save_model()

3.3.3.3. Batch Submission Script

Now we prepare the batch script with the content below and save it in our workspace folder (<SHARED_STORAGE_ROOT>/asr/train-whisper-qlora.sh).

#!/bin/bash

##SBATCH --job-name=asr
##SBATCH --nodes=2
##SBATCH --gpus-per-node=8
#SBATCH --account=<SLURM_ACCOUNT>
#SBATCH --output=%x_%j.out
#SBATCH --error=%x_%j.err
#SBATCH --partition=<SLURM_PARTITION>
#SBATCH --time=01:00:00
#SBATCH --exclusive
#SBATCH --ntasks-per-node=1

export OMPI_MCA_coll_hcoll_enable=0
export UCX_TLS=rc
export UCX_NET_DEVICES=mlx5_0:1,mlx5_1:1,mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export NCCL_SOCKET_IFNAME=eth0
export NCCL_IB_PCI_RELAXED_ORDERING=1
export NCCL_TOPO_FILE=/cm/shared/etc/ndv4-topo.xml
export NCCL_DEBUG=INFO
export NCCL_PROTO=LL,LL128,Simple
export NCCL_ALGO=Tree,Ring,CollnetDirect,CollnetChain,NVLS
export MELLANOX_VISIBLE_DEVICES=all
export PMIX_MCA_gds=hash
export PMIX_MCA_psec=native

export SHARED_STORAGE_ROOT=<SHARED_STORAGE_ROOT>
export CONTAINER_WORKSPACE_MOUNT=$SHARED_STORAGE_ROOT/asr
export CONTAINER_IMAGE=$SHARED_STORAGE_ROOT/pytorch-vidasr-24.01.sqsh

export ASR_DATASETS=bengali-ai-asr-10k

export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
export MASTER_PORT=$(( RANDOM % (50000 - 30000 + 1 ) + 30000 ))
export GPUS_PER_NODE=$SLURM_GPUS_PER_NODE
export NNODES=$SLURM_NNODES
export NUM_PROCESSES=$(expr $NNODES \* $GPUS_PER_NODE)
export MULTIGPU_FLAG="--multi_gpu"

if [ $NNODES == "1" ]
then
        export MULTIGPU_FLAG=""
fi

echo "MASTER_ADDR: $MASTER_ADDR"
echo "MASTER_PORT: $MASTER_PORT"
echo "Using $NNODES nodes, $NUM_PROCESSES GPUs total"

srun -l --container-image $CONTAINER_IMAGE \
        --container-mounts /cm/shared/etc:/cm/shared/etc,$CONTAINER_WORKSPACE_MOUNT:/workspace \
        --container-workdir /workspace \
        --no-container-mount-home \
        bash -c 'accelerate launch      --main_process_ip ${MASTER_ADDR} \
                                        --main_process_port ${MASTER_PORT} \
                                        --machine_rank $SLURM_NODEID \
                                        $MULTIGPU_FLAG \
                                        --same_network \
                                        --num_processes $NUM_PROCESSES \
                                        --num_cpu_threads_per_process 4 \
                                        --num_machines $NNODES qlora-asr.py'

As in the earlier video classification example, we need to configure some parts of the batch script for the Slurm environment.

  • <SLURM_ACCOUNT>: The Slurm account to be selected for the job.

  • <SLURM_PARTITION>: The Slurm partition to submit for the job.

  • <SHARED_STORAGE_ROOT>: The root of the user shared scratch space.

We leave the job-name, nodes, and gpus-per-node options unset in this script and instead assign them as arguments to the sbatch command, as shown in the following table.

Job submission configurations

Number of nodes | GPUs per node | Job submission command (in the script folder <SHARED_STORAGE_ROOT>/asr)
1 | 1 | sbatch --job-name=asr-n1g1 --nodes=1 --gpus-per-node=1 train-whisper-qlora.sh
1 | 2 | sbatch --job-name=asr-n1g2 --nodes=1 --gpus-per-node=2 train-whisper-qlora.sh
1 | 4 | sbatch --job-name=asr-n1g4 --nodes=1 --gpus-per-node=4 train-whisper-qlora.sh
1 | 8 | sbatch --job-name=asr-n1g8 --nodes=1 --gpus-per-node=8 train-whisper-qlora.sh
2 | 8 | sbatch --job-name=asr-n2g8 --nodes=2 --gpus-per-node=8 train-whisper-qlora.sh

3.3.3.4. Training Steps, Epochs, and Time

With a fixed number of training epochs, the number of training steps is inversely proportional to the number of GPUs (rounded up). Run the following command in the workspace folder to view the last section of each job's stderr, which contains the training time and the number of training steps.

cd <SHARED_STORAGE_ROOT>/asr
tail <job-name>_<job_id>.err

where <job-name> and <job_id> are the designated job name from the sbatch command and the job ID given by Slurm when submitted. The last line of the file should have a progress bar of 100% completion. The following is a result from a single-node, single-GPU job as an example.

100%|██████████| 125/125 [09:22<00:00,  4.50s/it]

  • 125/125: This is the completed/total number of training steps.

  • 09:22<00:00: The left number is the elapsed time, and the right number (right of the < symbol) is the estimated remaining training time. This example shows the final total elapsed time with no remaining training time.

The number of training steps and reference training timings are in the table below.

Job submission configurations and timings

Number of nodes | GPUs per node | Number of steps with 1 training epoch | Elapsed time of 1 training epoch (seconds)
1 | 1 | 125 | 371
1 | 2 | 63 | 183
1 | 4 | 32 | 93
1 | 8 | 16 | 49
2 | 8 | 8 | 25
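
The step counts in the table follow from ceiling division of the number of training samples by the global batch size (per-device batch size 8 times the number of GPUs). The sketch below reproduces the column, assuming the single training parquet file holds about 1,000 samples, which is consistent with 125 steps on one GPU.

# Reproducing the step counts above: steps = ceil(train_samples / (per_device_batch * num_gpus))
import math

train_samples = 1000      # assumption, consistent with 125 steps at per-device batch size 8 on 1 GPU
per_device_batch = 8

for num_gpus in (1, 2, 4, 8, 16):
    steps = math.ceil(train_samples / (per_device_batch * num_gpus))
    print(num_gpus, steps)   # 125, 63, 32, 16, 8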