SteerLM 2.0: Iterative Training for Attribute-Conditioned Language Model Alignment

SteerLM 2.0 is a novel approach for aligning large language models (LLMs) to generate responses with desired attribute values, building upon the original SteerLM method [1]. While SteerLM performs attribute-conditioned supervised fine-tuning to steer LLM outputs, SteerLM 2.0 introduces an iterative training procedure that explicitly trains the model to generate responses that follow the desired attribute-conditioned distribution.

Overview

The goal of SteerLM 2.0 is to train a model \(Q_\theta(y|a, x)\) that can generate responses \(y\) conditioned on a prompt \(x\) and desired attributes \(a\), while approximating the optimal conditional distribution \(P(y|a, x)\) derived from an attribute prediction model \(P(a|x, y)\) and an unconditional response model \(P(y|x)\). SteerLM 2.0 accomplishes this by minimizing the Kullback-Leibler (KL) divergence between \(P(y|a, x)\) and \(Q_\theta(y|a, x)\):

\[\min_\theta \mathbb{E}_{a, x} D_{KL}(P(y|a, x) || Q_\theta(y|a, x))\]

This KL divergence loss can be optimized using samples from an initial SteerLM model \(Q'(y|a, x)\), leading to an efficient gradient estimation procedure (see [2] for the derivation).
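
Because only the cross-entropy term of the KL divergence depends on \(\theta\), its gradient is an expectation under \(P(y|a, x)\), which can in turn be rewritten as an importance-weighted expectation under the sampling distribution \(Q'(y|a, x)\):

\[\nabla_\theta D_{KL}(P(y|a, x) || Q_\theta(y|a, x)) = -\mathbb{E}_{y \sim P(y|a, x)} \nabla_\theta \log Q_\theta(y|a, x) = -\mathbb{E}_{y \sim Q'(y|a, x)} \left[ \frac{P(y|a, x)}{Q'(y|a, x)} \nabla_\theta \log Q_\theta(y|a, x) \right]\]

Self-normalizing the importance ratios over a finite set of sampled responses and subtracting a baseline yields the practical estimator given in the method details below.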

Method Details

Construct the optimal conditional distribution \(P(y|a, x)\): Using Bayes’ rule and the attribute prediction model \(P(a|x, y)\), we can derive the optimal conditional distribution as:

\[P(y|a, x) \propto P(a|x, y) P(y|x)\]
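
Written out with its normalizing constant (which does not depend on \(y\) and is therefore dropped in the proportional form above), Bayes’ rule gives:

\[P(y|a, x) = \frac{P(a|x, y) P(y|x)}{\sum_{y'} P(a|x, y') P(y'|x)}\]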

Train the SteerLM 2.0 model \(Q_\theta(y|a, x)\): The SteerLM 2.0 model \(Q_\theta(y|a, x)\) is trained to approximate \(P(y|a, x)\) by minimizing the KL divergence loss using samples from an initial SteerLM model \(Q'(y|a, x)\). The gradient is estimated as:

\[\nabla_\theta L \approx -\sum_{y_i \sim Q'(y|a, x)} (w'_i - b'_i) \nabla_{\theta} \log Q_{\theta}(y_i|a, x)\]

where \(w'_i\) are normalized importance weights targeting \(P(y|a, x)\) and \(b'_i\) is a baseline term used to stabilize optimization (see [2] for details).
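
As an illustration, the following sketch computes one plausible form of this estimator from the per-response log-probabilities stored in the training data (described in the next section). The exact weighting and baseline used in NeMo-Aligner follow [2]; in particular, the choice of \(b'_i\) below (self-normalized ratios of the current policy to the sampling policy) is an assumption made for this sketch.

import torch
import torch.nn.functional as F

def steerlm2_loss(log_p_a_given_xy, log_p_y_given_x, log_q_prime, log_q_theta):
    """Sketch of the SteerLM 2.0 loss for one prompt over n sampled responses.

    All arguments are tensors of shape [n], one entry per sampled response y_i:
        log_p_a_given_xy: log P(a|x, y_i) from the attribute prediction model
        log_p_y_given_x:  log P(y_i|x) from the SFT model
        log_q_prime:      log Q'(y_i|a, x) from the initial SteerLM model
        log_q_theta:      log Q_theta(y_i|a, x) from the model being trained
                          (the only tensor that requires gradients)
    """
    # Self-normalized importance weights targeting P(y|a, x), which is
    # proportional to P(a|x, y) P(y|x), with Q'(y|a, x) as the sampling distribution.
    w = F.softmax(log_p_a_given_xy + log_p_y_given_x - log_q_prime, dim=0)
    # Baseline (assumed form): self-normalized weights of the current policy
    # against the sampling policy, detached so it receives no gradients.
    b = F.softmax(log_q_theta - log_q_prime, dim=0).detach()
    # Minimizing this loss gives the gradient -sum_i (w_i - b_i) grad log Q_theta(y_i|a, x).
    return -torch.sum((w - b) * log_q_theta)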

Iterative Training (optional): SteerLM 2.0 training can be run for multiple iterations (e.g., \(n=2\)), using the policy optimized in one iteration to sample responses for training an improved policy in the next. In each iteration, multiple diverse responses are sampled from the current model and used for the next round of training.
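
A high-level sketch of this loop is shown below; the sampling, scoring, and training helpers passed in are hypothetical placeholders rather than NeMo-Aligner APIs.

def iterative_steerlm2(initial_policy, prompts_with_attributes,
                       sample_responses, score_log_probs, train_one_round,
                       n_iterations=2, num_samples=4):
    """Sketch of the optional iterative SteerLM 2.0 procedure.

    sample_responses, score_log_probs, and train_one_round are caller-supplied
    stand-ins for response generation, log-probability scoring with the
    attribute prediction / SFT / initial SteerLM models, and one round of
    SteerLM 2.0 training, respectively.
    """
    policy = initial_policy  # starts from the initial SteerLM model Q'(y|a, x)
    for _ in range(n_iterations):
        dataset = []
        for prompt, attributes in prompts_with_attributes:
            # Sample multiple diverse responses from the current policy.
            responses = sample_responses(policy, prompt, attributes, num_samples)
            # Score each response and build a training record (see data format below).
            dataset.append(score_log_probs(prompt, attributes, responses))
        # Train an improved policy; it becomes the sampling policy for the next round.
        policy = train_one_round(policy, dataset)
    return policy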

By iteratively training on this loss, SteerLM 2.0 can learn to generate responses \(y\) that better conform to specified attribute values \(a\) for a given prompt \(x\).

Train a SteerLM 2.0 Model

Preparing the Training Dataset

SteerLM 2.0 requires training data in a specific format. The method relies on the following components:

  • A supervised fine-tuning (SFT) model \(P(y|x)\) that generates responses \(y\) given a prompt \(x\)

  • An original SteerLM model \(Q'(y|a, x)\) that generates responses \(y\) conditioned on attributes \(a\) and prompt \(x\)

The SteerLM 2.0 model \(Q_\theta(y|a, x)\) is initialized with the weights from \(Q'(y|a, x)\) and optimized to approximate the optimal conditional distribution \(P(y|a, x)\) derived from the attribute prediction model \(P(a|x, y)\) and the unconditional response model \(P(y|x)\).

To facilitate this training process, a specific data format is proposed:

{
  "system": "system prompt",
  "prompt_turns": [
    {"from": "User", "value": "x_user_turn_1"},
    {"from": "Assistant", "value": "x_assistant_turn_1"},
    {"from": "User", "value": "x_user_turn_2"}
  ],
  "label": "a",
  "responses": [
    {
      "from": "Assistant",
      "value": "y_1",
      "log(P(a|x,y))": "v1",
      "log(P(y|x))": "v2",
      "log(Q(y|a,x))": "v3"
    },
    {
      "from": "Assistant",
      "value": "y_2",
      "log(P(a|x,y))": "v1",
      "log(P(y|x))": "v2",
      "log(Q(y|a,x))": "v3"
    },
    ...
    {
      "from": "Assistant",
      "value": "y_n",
      "log(P(a|x,y))": "v1",
      "log(P(y|x))": "v2",
      "log(Q(y|a,x))": "v3"
    }
  ]
}

For a given attribute string \(a\) and prompt \(x\) (constructed from the prompt turns and the system turn), \(n\) responses \(y_i\) are sampled. To compute the loss, the following values are required:

  • \(\log P(a|x, y_i)\): The attribute prediction model’s output log-probability for the attributes \(a\) given the prompt \(x\) and response \(y_i\)

  • \(\log P(y_i|x)\): The unconditional response model’s output log-probability for the response \(y_i\) given the prompt \(x\)

  • \(\log Q'(y_i|a, x)\): The original SteerLM model’s output log-probability for the response \(y_i\) given the attributes \(a\) and prompt \(x\)

These values are provided as log(P(a|x,y)), log(P(y|x)), and log(Q(y|a,x)), respectively, for each sampled response \(y_i\).
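
As a rough sketch of how one such record might be assembled, the snippet below uses hypothetical scoring models exposing a log_prob method (placeholders, not NeMo-Aligner APIs); only the output keys follow the format described above.

import json

def build_record(system, prompt_turns, label, responses,
                 attr_model, sft_model, steerlm_model):
    """Assemble one SteerLM 2.0 training record as a JSON line.

    attr_model, sft_model, and steerlm_model are assumed to expose a
    log_prob(...) scoring method; they are illustrative placeholders.
    """
    record = {
        "system": system,
        "prompt_turns": prompt_turns,
        "label": label,
        "responses": [],
    }
    for y in responses:
        record["responses"].append({
            "from": "Assistant",
            "value": y,
            # log P(a|x, y): attribute prediction model score
            "log(P(a|x,y))": attr_model.log_prob(label, prompt_turns, y),
            # log P(y|x): unconditional (SFT) response model score
            "log(P(y|x))": sft_model.log_prob(prompt_turns, y),
            # log Q'(y|a, x): initial SteerLM model score
            "log(Q(y|a,x))": steerlm_model.log_prob(prompt_turns, y, label),
        })
    return json.dumps(record)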

Training Example

By organizing the data in this format, the SteerLM 2.0 model can be effectively trained to generate responses that conform to the desired attribute values while approximating the optimal conditional distribution \(P(y|a, x)\). The following is an example of launching SteerLM 2.0 training:

python examples/nlp/gpt/train_steerlm2.py \
     trainer.num_nodes=32 \
     trainer.devices=8 \
     trainer.precision=bf16 \
     trainer.sft.limit_val_batches=40 \
     trainer.sft.max_epochs=1 \
     trainer.sft.max_steps=800 \
     trainer.sft.val_check_interval=800 \
     trainer.sft.save_interval=800 \
     model.megatron_amp_O2=True \
     model.restore_from_path=/models/llama70b \
     model.tensor_model_parallel_size=8 \
     model.pipeline_model_parallel_size=2 \
     model.optim.lr=6e-6 \
     model.optim.name=distributed_fused_adam \
     model.optim.weight_decay=0.01 \
     model.optim.sched.constant_steps=200 \
     model.optim.sched.warmup_steps=1 \
     model.optim.sched.min_lr=5e-6 \
     model.answer_only_loss=True \
     model.activations_checkpoint_granularity=selective \
     model.activations_checkpoint_method=uniform \
     model.steerlm2.micro_batch_size=2 \
     model.steerlm2.forward_micro_batch_size=2 \
     model.data.chat=True \
     model.data.num_workers=0 \
     model.data.chat_prompt_tokens.system_turn_start=\'\<extra_id_0\>\' \
     model.data.chat_prompt_tokens.turn_start=\'\<extra_id_1\>\' \
     model.data.chat_prompt_tokens.label_start=\'\<extra_id_2\>\' \
     model.data.train_ds.max_seq_length=4096 \
     model.data.train_ds.micro_batch_size=1 \
     model.data.train_ds.global_batch_size=128 \
     model.data.train_ds.file_path=data/oasst/train_labeled_2ep.jsonl \
     model.data.train_ds.index_mapping_dir=/indexmap_dir \
     model.data.train_ds.add_eos=False \
     model.data.train_ds.hf_dataset=True \
     model.data.validation_ds.max_seq_length=4096 \
     model.data.validation_ds.file_path=data/oasst/val_labeled.jsonl \
     model.data.validation_ds.micro_batch_size=1 \
     model.data.validation_ds.global_batch_size=128 \
     model.data.validation_ds.index_mapping_dir=/indexmap_dir \
     model.data.validation_ds.add_eos=False \
     model.data.validation_ds.hf_dataset=True \
     exp_manager.create_wandb_logger=True \
     exp_manager.wandb_logger_kwargs.project=steerlm \
     exp_manager.wandb_logger_kwargs.name=acsft_training \
     exp_manager.explicit_log_dir=/results/acsft_70b \
     exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True

Inference

Since the SteerLM 2.0 Model is an extension of the original SteerLM model, the inference process is similar. Please refer to the SteerLM documentation for more details.

References

[1] Dong, Y., Delalleau, O., Zeng, J., Shen, G., Zhang, J.J., Sreedhar, M.N., Kuchaiev, O. (2023). SteerLM: Attribute Conditioned SFT as an (User-Steerable) Alternative to RLHF.

[2] Wang, Z., Dong, Y., Delalleau, O., Zeng, J., Shen, G., Zhang, J.J., Sreedhar, M.N., Kuchaiev, O. (2024). HelpSteer2: Open-source dataset for training top-performing reward models.