Post-training for Action-Conditioned Video Prediction#
This section provides instructions for post-training Predict2 Video2World models for action-conditioned video prediction.
Set up the Video2World Model#
1. Ensure you have the necessary hardware and software, as outlined on the Prerequisites page.
2. Follow the Installation guide to download the Cosmos-Predict2 repo and set up the environment.
3. Generate a Hugging Face access token. Set the access token permission to ‘Read’ (the default permission is ‘Fine-grained’).
4. Log in to Hugging Face with the access token:
huggingface-cli login
5. Review and accept the Llama-Guard-3-8B terms.
Prepare the Data#
Download Bridge Training Dataset#
We use the train/validation splits of the Bridge dataset from IRASim for action-conditioned post-training.
To download and prepare the dataset, run the following commands under the cosmos-predict2/ directory:
wget https://lf-robot-opensource.bytetos.com/obj/lab-robot-public/opensource_IRASim_v1/bridge_train_data.tar.gz
mv bridge_train_data.tar.gz datasets/
cd datasets
tar -xvzf bridge_train_data.tar.gz -C .
mv opensource_robotdata/bridge ./
Your dataset directory structure should look like this:
datasets/bridge/
├── annotations/
│   ├── *.json
├── videos/
│   ├── *.mp4
Each JSON file in the annotations/ folder contains the end-effector pose and gripper width of the robot arm for each frame in the corresponding video.
Specifically, each file includes:
- state: The end-effector pose of the robot arm at each timestep, represented as [x, y, z, roll, pitch, yaw].
  - (x, y, z) denotes the gripper’s position in world coordinates.
  - (roll, pitch, yaw) describes its orientation in Euler angles.
- continuous_gripper_state: The width of the gripper at each timestep, indicating whether it is open or closed. A value of 0 means the gripper is open, and 1 means it is closed.
- action: The gripper’s displacement at each timestep.
  - The first six dimensions represent displacement in (x, y, z, roll, pitch, yaw) within the gripper coordinate frame.
  - The last (seventh) dimension is a binary value indicating whether the gripper should open (1) or close (0).
We use this information as conditioning input for video generation.
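To sanity-check the download, you can load one annotation file and inspect these fields. The path below is an example (any file in the folder works); the key names follow the description above:
import json

# Any annotation file from the dataset works here; this path is an example.
with open("datasets/bridge/annotations/example.json") as f:
    ann = json.load(f)

states = ann["state"]                     # per-frame [x, y, z, roll, pitch, yaw]
widths = ann["continuous_gripper_state"]  # per-frame gripper width (0 = open, 1 = closed)
actions = ann["action"]                   # per-frame 7-D displacement + open/close bit

print(f"{len(states)} frames")
print("first state: ", states[0])
print("first action:", actions[0])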
Post-training#
Cosmos-Predict2-2B-Video2World#
Run the following command to launch an example post-training job using the Bridge dataset:
torchrun --nproc_per_node=2 --master_port=12341 -m scripts.train --config=cosmos_predict2/configs/base/config.py -- experiment="action_conditioned_predict2_video2world_2b_training"
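The --nproc_per_node flag sets the number of GPUs used for training; adjust it to match your hardware. If port 12341 is already in use, pass a different --master_port.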
See cosmos_predict2/configs/action_conditioned/defaults/data.py to understand how the dataloader is defined.
To add action as an additional conditioning input, we create a new conditioner that supports action in cosmos_predict2/configs/action_conditioned/defaults/conditioner.py.
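As a rough illustration of what such a conditioner does, the sketch below embeds per-frame 7-D action vectors into a hidden dimension so they can be fused with the video latents. All class names and dimensions here are hypothetical and do not mirror the actual classes in conditioner.py:
import torch
from torch import nn

class ActionEmbedder(nn.Module):
    """Hypothetical sketch: projects per-frame 7-D actions into conditioning embeddings."""

    def __init__(self, action_dim: int = 7, hidden_dim: int = 1024):
        super().__init__()
        # Small MLP mapping raw action vectors to the model's hidden size
        self.mlp = nn.Sequential(
            nn.Linear(action_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, actions: torch.Tensor) -> torch.Tensor:
        # actions: (batch, frames, 7) -> (batch, frames, hidden_dim)
        return self.mlp(actions)

# Example: a batch of 2 clips with 12 action steps each
embeddings = ActionEmbedder()(torch.randn(2, 12, 7))
print(embeddings.shape)  # torch.Size([2, 12, 1024])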
Checkpoint Output Structure#
Checkpoints are saved to the following path:
checkpoints/PROJECT/GROUP/NAME
For the example command above:
PROJECT: posttraining
GROUP: video2world
NAME: action_conditioned_predict2_video2world_2b_training_${now:%Y-%m-%d}_${now:%H-%M-%S}
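For example, a run launched on 2025-01-01 at 12:00:00 (timestamp illustrative) would write checkpoints under:
checkpoints/posttraining/video2world/action_conditioned_predict2_video2world_2b_training_2025-01-01_12-00-00/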
Configuration Snippet#
Below is a configuration snippet defining the experiment setup:
action_conditioned_predict2_video2world_2b_training = dict(
    defaults=[
        {"override /model": "action_conditioned_predict2_v2w_2b_fsdp"},
        {"override /optimizer": "fusedadamw"},
        {"override /scheduler": "lambdalinear"},
        {"override /ckpt_type": "standard"},
        {"override /data_train": "bridge_train"},
        "_self_",
    ],
    model=dict(
        config=dict(
            fsdp_shard_size=-1,
        )
    ),
    job=dict(group="debug", name="action_conditioned_predict2_video2world_2b_training_${now:%Y-%m-%d}_${now:%H-%M-%S}"),
)
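The dict name, action_conditioned_predict2_video2world_2b_training, is the value passed to the experiment= argument in the training command above.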
Inference#
Cosmos-Predict2-2B-Video2World#
To run inference using a post-trained checkpoint (e.g., at 1000 iterations), use the command below.
Specify the path to the checkpoint using the --dit_path argument, replacing the ${now:...} placeholders with the actual timestamp of your training run:
CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python examples/action_video2world.py \
--model_size 2B \
--dit_path "checkpoints/posttraining/video2world/action_conditioned_predict2_video2world_2b_training_${now:%Y-%m-%d}_${now:%H-%M-%S}/checkpoints/model/iter_000001000.pt" \
--input_video datasets/bridge/videos/test/13/rgb.mp4 \
--input_annotation datasets/bridge/annotation/test/13.json \
--num_conditional_frames 1 \
--save_path output/generated_video.mp4 \
--guidance 0 \
--seed 0 \
--disable_guardrail \
--disable_prompt_refiner
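In this command, --num_conditional_frames 1 conditions generation on a single input frame, --guidance 0 disables classifier-free guidance, and the --disable_guardrail / --disable_prompt_refiner flags skip the auxiliary guardrail and prompt-refiner models, which are not needed for action-conditioned inference.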