Post-training for Action-Conditioned Video Prediction

This section provides instructions for post-training Cosmos-Predict2 Video2World models for action-conditioned video prediction.

Set up the Video2World Model

  1. Ensure you have the necessary hardware and software, as outlined on the Prerequisites page.

  2. Follow the Installation guide to download the Cosmos-Predict2 repo and set up the environment.

  3. Generate a Hugging Face access token. Set the access token permission to ‘Read’ (the default permission is ‘Fine-grained’).

  4. Log in to Hugging Face with the access token:

    huggingface-cli login
    
  5. Review and accept the Llama-Guard-3-8B terms.

Prepare the Data

Download Bridge Training Dataset

We use the train/validation splits of the Bridge dataset from IRASim for action-conditioned post-training. To download and prepare the dataset, run the following commands from the cosmos-predict2/ root directory:

mkdir -p datasets
wget https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/opensource_IRASim_v1/bridge_train_data.tar.gz
mv bridge_train_data.tar.gz datasets/
cd datasets
tar -xvzf bridge_train_data.tar.gz -C .
mv opensource_robotdata/bridge ./

Your dataset directory structure should look like this:

datasets/bridge/
├── annotations/
│   └── *.json
└── videos/
    └── *.mp4
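After extraction, a quick sanity check can confirm that both folders are populated. The sketch below is a minimal helper under a few assumptions: the function name summarize_bridge_dataset is hypothetical, it searches recursively so that split subfolders (such as the test/ folder used in the inference example) are counted, and its "annotation*" pattern matches both the annotations/ and annotation/ spellings that appear in this guide.

```python
from pathlib import Path


def summarize_bridge_dataset(root="datasets/bridge"):
    """Count annotation and video files under the Bridge dataset root (hypothetical helper).

    Searches recursively, so split subfolders such as train/ or test/ are
    included. "annotation*" matches both "annotation/" and "annotations/".
    """
    root = Path(root)
    if not root.is_dir():
        raise FileNotFoundError(f"dataset root not found: {root}")
    # "**" matches zero or more directories, so top-level files are counted too.
    annotations = sorted(root.glob("annotation*/**/*.json"))
    videos = sorted(root.glob("videos/**/*.mp4"))
    return {"annotations": len(annotations), "videos": len(videos)}


# Example (run from the cosmos-predict2/ directory after extraction):
# print(summarize_bridge_dataset())
```

Comparing the two counts is a cheap way to spot an incomplete download or extraction before launching a training job.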

Each JSON file in the annotations/ folder contains the end-effector pose and gripper width of the robot arm for each frame in the corresponding video. Specifically, each file includes:

  • state: The end-effector pose of the robot arm at each timestep, represented as [x, y, z, roll, pitch, yaw].

    • (x, y, z) denotes the gripper’s position in world coordinates.

    • (roll, pitch, yaw) describes its orientation in Euler angles.

  • continuous_gripper_state: The width of the gripper at each timestep, indicating whether it is open or closed. A value of 0 means the gripper is open, and 1 means it is closed.

  • action: The gripper’s displacement at each timestep.

    • The first six dimensions represent displacement in (x, y, z, roll, pitch, yaw) within the gripper coordinate frame.

    • The last (seventh) dimension is a binary value indicating whether the gripper should open (1) or close (0).
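How the six displacement dimensions could relate to the recorded states can be written out explicitly. The derivation below assumes, without confirmation from the dataset description, that each action is the relative pose between consecutive states expressed in the current gripper frame. Here R(·) builds a rotation matrix from Euler angles in the dataset's (unspecified) convention, euler(·) is its inverse, and g_t is the gripper command (1 = open, 0 = close).

```latex
\begin{aligned}
R_t &= R(\text{roll}_t, \text{pitch}_t, \text{yaw}_t), \qquad p_t = (x_t, y_t, z_t)^\top \\
\Delta p_t &= R_t^\top \,(p_{t+1} - p_t) \\
\Delta R_t &= R_t^\top R_{t+1}, \qquad (\Delta\text{roll}_t, \Delta\text{pitch}_t, \Delta\text{yaw}_t) = \operatorname{euler}(\Delta R_t) \\
a_t &= \big[\,\Delta p_t^\top,\ \Delta\text{roll}_t,\ \Delta\text{pitch}_t,\ \Delta\text{yaw}_t,\ g_t\,\big]
\end{aligned}
```

Multiplying by R_t^T is what moves the world-frame quantities into the gripper frame, which is why the same world-space motion can produce different actions depending on the gripper's orientation.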

We use this information as conditioning input for video generation.
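A minimal sketch of reading one annotation file into arrays follows. It assumes the top-level keys are exactly the field names listed above, that each holds one entry per timestep, and that NumPy is available; parse_annotation and load_annotation are hypothetical helper names, not functions from the repository.

```python
import json

import numpy as np


def parse_annotation(ann):
    """Convert one decoded annotation dict into NumPy arrays (hypothetical helper).

    Assumes the keys "state", "continuous_gripper_state" and "action",
    each holding one entry per timestep.
    """
    state = np.asarray(ann["state"], dtype=np.float32)  # (T, 6): x, y, z, roll, pitch, yaw
    gripper = np.asarray(ann["continuous_gripper_state"], dtype=np.float32)  # 0 open, 1 closed
    action = np.asarray(ann["action"], dtype=np.float32)  # (N, 7): displacement + gripper command
    # These checks encode the layout described above; relax them if your files differ.
    if state.ndim != 2 or state.shape[1] != 6:
        raise ValueError(f"expected state with shape (T, 6), got {state.shape}")
    if action.ndim != 2 or action.shape[1] != 7:
        raise ValueError(f"expected action with shape (N, 7), got {action.shape}")
    return {
        "state": state,
        "gripper": gripper.reshape(-1),
        "displacement": action[:, :6],  # (dx, dy, dz, droll, dpitch, dyaw) in the gripper frame
        "gripper_command": action[:, 6],  # 1 = open, 0 = close
    }


def load_annotation(path):
    """Read an annotation JSON file from disk and parse it."""
    with open(path, "r", encoding="utf-8") as f:
        return parse_annotation(json.load(f))
```

Splitting the continuous displacement from the binary gripper command keeps the two kinds of signal separate, which is convenient if they are normalized differently before conditioning.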

Post-training

Cosmos-Predict2-2B-Video2World

Run the following command to launch an example post-training job using the Bridge dataset:

torchrun --nproc_per_node=2 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment="action_conditioned_predict2_video2world_2b_training"

See cosmos_predict2/configs/action_conditioned/defaults/data.py for how the dataloader is defined. To add the action as an additional condition, a new conditioner that supports actions is defined in cosmos_predict2/configs/action_conditioned/defaults/conditioner.py.
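To illustrate what an action conditioner has to do, here is a hypothetical sketch, not the implementation in conditioner.py: it maps a per-frame 7-D action sequence to embedding vectors with a small two-layer MLP in NumPy. The class name, the SiLU activation, and the hidden and embedding sizes are all assumptions; a real conditioner would use trainable framework parameters and match the model's hidden dimension.

```python
import numpy as np


class ActionEmbedder:
    """Hypothetical action embedding: (..., action_dim) -> (..., embed_dim).

    A two-layer MLP with a SiLU activation and randomly initialised weights.
    Illustrative only; sizes and architecture are assumptions.
    """

    def __init__(self, action_dim=7, hidden_dim=256, embed_dim=1024, seed=0):
        rng = np.random.default_rng(seed)
        # Scale initial weights by 1/sqrt(fan_in) to keep activations well-conditioned.
        self.w1 = rng.normal(0.0, 1.0 / np.sqrt(action_dim), size=(action_dim, hidden_dim))
        self.b1 = np.zeros(hidden_dim)
        self.w2 = rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), size=(hidden_dim, embed_dim))
        self.b2 = np.zeros(embed_dim)

    def __call__(self, actions):
        actions = np.asarray(actions, dtype=np.float64)
        if actions.shape[-1] != self.w1.shape[0]:
            raise ValueError(f"expected last dim {self.w1.shape[0]}, got {actions.shape[-1]}")
        h = actions @ self.w1 + self.b1  # matmul broadcasts over leading (batch, time) dims
        h = h / (1.0 + np.exp(-h))  # SiLU: x * sigmoid(x)
        return h @ self.w2 + self.b2
```

Because the matrix product broadcasts over leading dimensions, a batch of shape (batch, frames, 7) maps to (batch, frames, embed_dim), giving one conditioning vector per frame.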

Checkpoint Output Structure

Checkpoints are saved to the following path:

checkpoints/PROJECT/GROUP/NAME

For the example command above:

  • PROJECT: posttraining

  • GROUP: video2world

  • NAME: action_conditioned_predict2_video2world_2b_training_${now:%Y-%m-%d}_${now:%H-%M-%S}
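Because NAME ends in a launch timestamp, locating a run's newest checkpoint by hand can be tedious. The helper below is a sketch that assumes the checkpoints/model/iter_<N>.pt layout shown in the inference command in this guide; find_latest_checkpoint and its default arguments are hypothetical.

```python
import re
from pathlib import Path

ITER_PATTERN = re.compile(r"^iter_(\d+)\.pt$")


def find_latest_checkpoint(
    group_dir="checkpoints/posttraining/video2world",
    name_prefix="action_conditioned_predict2_video2world_2b_training_",
):
    """Return the highest-iteration checkpoint of the most recent run (hypothetical helper)."""
    runs = sorted(p for p in Path(group_dir).glob(name_prefix + "*") if p.is_dir())
    if not runs:
        raise FileNotFoundError(f"no runs matching {name_prefix}* in {group_dir}")
    # The zero-padded %Y-%m-%d_%H-%M-%S suffix sorts chronologically as plain text.
    latest_run = runs[-1]
    candidates = []
    for path in (latest_run / "checkpoints" / "model").glob("iter_*.pt"):
        match = ITER_PATTERN.match(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        raise FileNotFoundError(f"no iter_*.pt files in {latest_run}")
    # Compare by the parsed iteration number, not by file name.
    return max(candidates, key=lambda item: item[0])[1]
```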

Configuration Snippet

Below is a configuration snippet defining the experiment setup:

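action_conditioned_predict2_video2world_2b_training = dict(
    defaults=[
        # 2B action-conditioned Video2World model (FSDP variant)
        {"override /model": "action_conditioned_predict2_v2w_2b_fsdp"},
        # Fused AdamW optimizer with a LambdaLinear learning-rate scheduler
        {"override /optimizer": "fusedadamw"},
        {"override /scheduler": "lambdalinear"},
        {"override /ckpt_type": "standard"},
        # Bridge training split (dataloader defined in defaults/data.py)
        {"override /data_train": "bridge_train"},
        # Listed last, so the values in this dict override the defaults above
        "_self_",
    ],
    model=dict(
        config=dict(
            fsdp_shard_size=-1,
        )
    ),
    # The run name ends in a launch timestamp filled in by the ${now:...} resolver
    job=dict(group="debug", name="action_conditioned_predict2_video2world_2b_training_${now:%Y-%m-%d}_${now:%H-%M-%S}"),
)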
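The "_self_" entry controls composition order in Hydra-style defaults lists: placed last, the experiment's own values are merged on top of the selected defaults. The recursive merge below approximates that behavior for plain dicts (nested dicts merge key by key, other values are replaced). It is an illustration, not Hydra's or OmegaConf's implementation, and the default values in it are hypothetical.

```python
import copy


def deep_merge(base, override):
    """Return a copy of `base` with `override` applied on top.

    Nested dicts are merged key by key; any other value in `override`
    (including lists) replaces the value in `base`.
    """
    result = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = deep_merge(result[key], value)
        else:
            result[key] = copy.deepcopy(value)
    return result


# Hypothetical values standing in for the composed defaults.
defaults = {"model": {"config": {"fsdp_shard_size": 8, "precision": "bfloat16"}}}
# Experiment-level values, applied last because "_self_" is listed last.
experiment = {"model": {"config": {"fsdp_shard_size": -1}}}

merged = deep_merge(defaults, experiment)
# merged["model"]["config"] == {"fsdp_shard_size": -1, "precision": "bfloat16"}
```

This is why the snippet only needs to list fsdp_shard_size: every other model setting is inherited unchanged from the selected model default.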

Inference

Cosmos-Predict2-2B-Video2World

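To run inference with a post-trained checkpoint (for example, the one saved at iteration 1000), use the command below. Pass the checkpoint path with the --dit_path argument, replacing the ${now:%Y-%m-%d}_${now:%H-%M-%S} placeholder with the timestamp of your training run's directory: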

CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python examples/action_video2world.py \
  --model_size 2B \
  --dit_path "checkpoints/posttraining/video2world/action_conditioned_predict2_video2world_2b_training_${now:%Y-%m-%d}_${now:%H-%M-%S}/checkpoints/model/iter_000001000.pt" \
  --input_video datasets/bridge/videos/test/13/rgb.mp4 \
  --input_annotation datasets/bridge/annotation/test/13.json \
  --num_conditional_frames 1 \
  --save_path output/generated_video.mp4 \
  --guidance 0 \
  --seed 0 \
  --disable_guardrail \
  --disable_prompt_refiner
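For scripted evaluation, the same invocation can be assembled from Python. This sketch copies the flags from the command above; build_inference_command and run_inference are hypothetical helpers, and the iter_<9-digit>.pt naming is inferred from the example path iter_000001000.pt.

```python
import os
import subprocess
import sys
from pathlib import Path


def build_inference_command(run_dir, iteration, input_video, input_annotation,
                            save_path="output/generated_video.mp4", seed=0):
    """Assemble the argument list for examples/action_video2world.py (hypothetical helper)."""
    # Checkpoints are named iter_<9-digit iteration>.pt, e.g. iter_000001000.pt.
    dit_path = Path(run_dir) / "checkpoints" / "model" / f"iter_{iteration:09d}.pt"
    return [
        sys.executable, "examples/action_video2world.py",
        "--model_size", "2B",
        "--dit_path", str(dit_path),
        "--input_video", str(input_video),
        "--input_annotation", str(input_annotation),
        "--num_conditional_frames", "1",
        "--save_path", str(save_path),
        "--guidance", "0",
        "--seed", str(seed),
        "--disable_guardrail",
        "--disable_prompt_refiner",
    ]


def run_inference(cmd):
    """Run the command with the same environment variables as the shell example."""
    env = dict(os.environ)
    env["PYTHONPATH"] = os.getcwd()
    if "CONDA_PREFIX" in env:
        env["CUDA_HOME"] = env["CONDA_PREFIX"]
    subprocess.run(cmd, env=env, check=True)
```

Passing the result of build_inference_command(...) to run_inference launches generation. Using sys.executable runs the script with the interpreter that imported the helper, which avoids picking up a different python from PATH.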