Create a spleen segmentation MMAR with the MMAR API
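This example builds a complete spleen segmentation MMAR programmatically: every training and validation piece is declared as a Component, the pieces are assembled into a TrainConfig and a ValidateConfig, and MMAR.build() writes the resulting archive (configs, environment, datalist, resources, and command scripts) to disk.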
from medl.tools.mmar_creator.train_config import TrainConfig
from medl.tools.mmar_creator.validate_config import ValidateConfig
from medl.tools.mmar_creator.component import Component
from medl.tools.mmar_creator.mmar import MMAR
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from ignite.metrics import Accuracy
from monai.data import CacheDataset, Dataset, DataLoader
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.handlers import (
LrScheduleHandler,
ValidationHandler,
CheckpointSaver,
StatsHandler,
TensorBoardStatsHandler,
MeanDice,
CheckpointLoader,
MetricsSaver,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
Activationsd,
AsDiscreted,
LoadImaged,
EnsureChannelFirstd,
ScaleIntensityRanged,
CropForegroundd,
RandCropByPosNegLabeld,
RandShiftIntensityd,
ToTensord,
Invertd,
SaveImaged,
)
def main():
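# A Component wraps a class to be instantiated from the generated config: `name` is the
# class and `args` are the constructor arguments written into the config file. In the
# argument values (following the MMAR config conventions used throughout this example),
# "@id" references another component, "#expr" is evaluated as a Python expression, and
# "{VAR}" is a placeholder resolved from the config variables or the environment file.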
# define train components
loss = Component(name=DiceLoss, args={"to_onehot_y": True, "softmax": True})
optimizer = Component(name=Adam, args={"params": "#@model.parameters()", "lr": "{learning_rate}"})
train_model = Component(name=UNet, args={"spatial_dims": 3, "in_channels": "{INPUT_CHANNELS}", "out_channels": "{OUTPUT_CHANNELS}", "channels": [16, 32, 64, 128, 256], "strides": [2, 2, 2, 2], "num_res_units": 2, "norm": "batch"})
lr_scheduler = Component(name=StepLR, args={"optimizer": "@optimizer", "step_size": 5000, "gamma": 0.1})
# define pre_transforms
load_image = Component(name=LoadImaged, args={"keys": ["image", "label"]})
ensure_channelfirst = Component(name=EnsureChannelFirstd, args={"keys": ["image", "label"]})
scale_intensity_range = Component(name=ScaleIntensityRanged, args={"keys": "image", "a_min": -57, "a_max": 164, "b_min": 0.0, "b_max": 1.0, "clip": True})
crop_foreground = Component(name=CropForegroundd, args={"keys": ["image", "label"], "source_key": "image"})
crop_foreground_eval = Component(name=CropForegroundd, args={"keys": "image", "source_key": "image"})
rand_crop_by_posneg_label = Component(name=RandCropByPosNegLabeld, args={"keys": ["image", "label"], "label_key": "label", "spatial_size": [96, 96, 96], "pos": 1, "neg": 1, "num_samples": 4, "image_key": "image", "image_threshold": 0})
rand_shift_intensity = Component(name=RandShiftIntensityd, args={"keys": "image", "offsets": 0.1, "prob": 0.5})
to_tensor = Component(name=ToTensord, args={"keys": ["image", "label"]})
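# the training pre-transforms follow the usual CT spleen recipe: window intensities from
# [-57, 164] HU into [0, 1], crop to the image foreground, then draw four random
# 96x96x96 positive/negative patches per volume and apply a small random intensity shift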
# define train dataset and dataloader
train_dataset = Component(
name=CacheDataset,
vars={"data_list_file_path": "{DATASET_JSON}", "data_file_base_dir": "{DATA_ROOT}", "data_list_key": "{TRAIN_DATALIST_KEY}"},
args={"transform": "@pre_transforms", "cache_num": 32, "cache_rate": 1.0, "num_workers": 4},
)
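# note: `vars` appear to be MMAR-level properties (here, where to find the datalist and
# which key to read from it) consumed by the Clara Train runtime, while `args` are passed
# to the CacheDataset constructor; this is an assumption about the mmar_creator API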
train_dataloader = Component(name=DataLoader, args={"dataset": "@dataset", "batch_size": 2, "shuffle": True, "num_workers": 4})
# define train inferer
train_inferer = Component(name=SimpleInferer)
# define train handlers
lr_scheduler_handler = Component(name=LrScheduleHandler, args={"lr_scheduler": "@lr_scheduler", "print_lr": True})
validation_handler = Component(name=ValidationHandler, args={"validator": "@evaluator", "epoch_level": True, "interval": "{num_interval_per_valid}"})
train_checkpoint_saver = Component(name=CheckpointSaver, vars={"rank": 0}, args={"save_dir": "{MMAR_CKPT_DIR}", "save_dict": {"model": "@model", "optimizer": "@optimizer", "lr_scheduler": "@lr_scheduler", "train_conf": "@conf"}, "save_final": True, "save_interval": 400})
train_stats_handler = Component(name=StatsHandler, vars={"rank": 0}, args={"tag_name": "train_loss", "output_transform": "#monai.handlers.from_engine(['loss'], first=True)"})
train_tb_stats_handler = Component(name=TensorBoardStatsHandler, vars={"rank": 0}, args={"log_dir": "{MMAR_CKPT_DIR}", "tag_name": "train_loss", "output_transform": "#monai.handlers.from_engine(['loss'], first=True)"})
# define post transforms
activation = Component(name=Activationsd, args={"keys": "pred", "softmax": True})
invert = Component(name=Invertd, args={"keys": "pred", "transform": "@pre_transforms", "orig_keys": "image", "meta_keys": "pred_meta_dict", "nearest_interp": False, "to_tensor": True, "device": "cuda"})
as_discrete = Component(name=AsDiscreted, args={"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2})
save_image = Component(name=SaveImaged, args={"keys": "pred", "meta_keys": "pred_meta_dict", "output_dir": "{MMAR_EVAL_OUTPUT_PATH}", "resample": False, "squeeze_end_dims": True})
# define train metrics
train_acc = Component(name=Accuracy, vars={"log_label": "train_acc"}, args={"output_transform": "#monai.handlers.from_engine(['pred', 'label'])"})
# define trainer
trainer = Component(
name=SupervisedTrainer,
args={
"max_epochs": "{epochs}",
"device": "cuda",
"train_data_loader": "@dataloader",
"network": "@model",
"loss_function": "@loss",
"optimizer": "@optimizer",
"inferer": "@inferer",
"postprocessing": "@post_transforms",
"key_train_metric": "@key_metric",
"train_handlers": "@handlers",
"amp": "{amp}",
}
)
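# the "@dataloader", "@inferer", "@post_transforms", "@key_metric" and "@handlers"
# references are presumably resolved against the components registered for the train
# section below (add_train_dataloader, add_train_inferer, add_train_post_transform, ...)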
# define validation dataset and dataloader
val_dataset = Component(
name=CacheDataset,
vars={"data_list_file_path": "{DATASET_JSON}", "data_file_base_dir": "{DATA_ROOT}", "data_list_key": "{VAL_DATALIST_KEY}"},
args={"transform": "@pre_transforms", "cache_num": 9, "cache_rate": 1.0, "num_workers": 4},
)
val_dataloader = Component(name=DataLoader, args={"dataset": "@dataset", "batch_size": 1, "shuffle": False, "num_workers": 4})
# define validation inferer
val_inferer = Component(name=SlidingWindowInferer, args={"roi_size": [160, 160, 160], "sw_batch_size": 4, "overlap": 0.5})
# define validation handlers
val_stats_handler = Component(name=StatsHandler, vars={"rank": 0}, args={"output_transform": "lambda x: None"})
val_tb_stats_handler = Component(name=TensorBoardStatsHandler, vars={"rank": 0}, args={"log_dir": "{MMAR_CKPT_DIR}", "output_transform": "lambda x: None"})
val_checkpoint_saver = Component(name=CheckpointSaver, vars={"rank": 0}, args={"save_dir": "{MMAR_CKPT_DIR}", "save_dict": {"model": "@model", "train_conf": "@conf"}, "save_key_metric": True})
# define validation metrics
val_mean_dice = Component(name=MeanDice, vars={"log_label": "val_mean_dice", "is_key_metric": True}, args={"include_background": False, "output_transform": "#monai.handlers.from_engine(['pred', 'label'])"})
val_acc = Component(name=Accuracy, vars={"log_label": "val_acc"}, args={"output_transform": "#monai.handlers.from_engine(['pred', 'label'])"})
# define evaluator
evaluator = Component(
name=SupervisedEvaluator,
args={
"device": "cuda",
"val_data_loader": "@dataloader",
"network": "@model",
"inferer": "@inferer",
"postprocessing": "@post_transforms",
"key_val_metric": "@key_metric",
"additional_metrics": "@additional_metrics",
"val_handlers": "@handlers",
"amp": "{amp}",
}
)
# define evaluation dataset
eval_dataset = Component(
name=Dataset,
vars={"data_list_file_path": "{DATASET_JSON}", "data_file_base_dir": "{DATA_ROOT}", "data_list_key": "{VAL_DATALIST_KEY}"},
args={"transform": "@pre_transforms"},
)
# define evaluation handlers
eval_checkpoint_loader = Component(name=CheckpointLoader, vars={"disabled": "{dont_load_ckpt_model}"}, args={"load_path": "{MMAR_CKPT}", "load_dict": {"model": "@model"}})
eval_metrics_saver = Component(name=MetricsSaver, args={"save_dir": "{MMAR_EVAL_OUTPUT_PATH}", "metrics": ["val_mean_dice", "val_acc"], "metric_details": ["val_mean_dice"], "batch_transform": "#monai.handlers.from_engine(['image_meta_dict'])", "summary_ops": "*", "save_rank": 0})
# create train config based on above components
train_conf = TrainConfig()
# set variables in config
train_conf.add_vars(epochs=1260, num_interval_per_valid=20, learning_rate=2e-4, multi_gpu=False, amp=True, cudnn_benchmark=False)
train_conf.add_loss(loss)
train_conf.add_optimizer(optimizer)
train_conf.add_model(train_model)
train_conf.add_lr_scheduler(lr_scheduler)
train_conf.add_train_pre_transform([
load_image,
ensure_channelfirst,
scale_intensity_range,
crop_foreground,
rand_crop_by_posneg_label,
rand_shift_intensity,
to_tensor,
])
train_conf.add_train_dataset(train_dataset)
train_conf.add_train_dataloader(train_dataloader)
train_conf.add_train_inferer(train_inferer)
train_conf.add_train_handler([
lr_scheduler_handler,
validation_handler,
train_checkpoint_saver,
train_stats_handler,
train_tb_stats_handler,
])
train_conf.add_train_post_transform([activation, as_discrete])
train_conf.add_train_key_metric(train_acc)
train_conf.add_trainer(trainer)
train_conf.add_val_pre_transform([
load_image,
ensure_channelfirst,
scale_intensity_range,
crop_foreground,
to_tensor,
])
train_conf.add_val_dataset(val_dataset)
train_conf.add_val_dataloader(val_dataloader)
train_conf.add_val_inferer(val_inferer)
train_conf.add_val_handler([
val_stats_handler,
val_tb_stats_handler,
val_checkpoint_saver,
])
train_conf.add_val_post_transform([activation, as_discrete])
train_conf.add_val_key_metric(val_mean_dice)
train_conf.add_val_additional_metric([val_acc])
train_conf.add_evaluator(evaluator)
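# train_conf now holds both the train and validate parts of config_train.json, the
# config file referenced by train.sh below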
# create validation config based on above components
eval_conf = ValidateConfig()
# set variables in the config
eval_conf.add_vars(multi_gpu=False, amp=True, dont_load_ts_model=False, dont_load_ckpt_model=True)
eval_conf.add_model([
Component(vars={"ts_path": "{MMAR_TORCHSCRIPT}", "disabled": "{dont_load_ts_model}"}),
Component(vars={"ckpt_path": "{MMAR_CKPT}", "disabled": "{dont_load_ckpt_model}"}),
])
eval_conf.add_pre_transform([
load_image,
ensure_channelfirst,
scale_intensity_range,
crop_foreground_eval,
to_tensor,
])
eval_conf.add_dataset(eval_dataset)
eval_conf.add_dataloader(val_dataloader)
eval_conf.add_inferer(val_inferer)
eval_conf.add_handler([
eval_checkpoint_loader,
val_stats_handler,
eval_metrics_saver,
])
eval_conf.add_post_transform([activation, invert, as_discrete, save_image])
eval_conf.add_key_metric(val_mean_dice)
eval_conf.add_additional_metric([val_acc])
eval_conf.add_evaluator(evaluator)
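# eval_conf corresponds to config_validation.json used by validate_ckpt.sh; it reuses the
# transform and metric components defined above, but adds Invertd and SaveImaged so that
# predictions are restored to the original geometry and written to {MMAR_EVAL_OUTPUT_PATH},
# and it uses the image-only CropForegroundd variant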
# create MMAR
mmar = MMAR(root_path="spleen_segmentation", datalist_name="dataset_0.json")
mmar.set_train_config(train_conf)
mmar.set_validate_config(eval_conf)
mmar.set_resources({"log.config": "[loggers]\nkeys=root,modelLogger\n[handlers]\nkeys=consoleHandler\n[formatters]\nkeys=fullFormatter\n[logger_root]\nlevel=INFO\nhandlers=consoleHandler\n[logger_modelLogger]\nlevel=DEBUG\nhandlers=consoleHandler\nqualname=modelLogger\npropagate=0\n[handler_consoleHandler]\nclass=StreamHandler\nlevel=DEBUG\nformatter=fullFormatter\nargs=(sys.stdout,)\n[formatter_fullFormatter]\nformat=%(asctime)s-%(name)s-%(levelname)s-%(message)s"})
mmar.set_environment(env={
"CLARA_TRAIN_VERSION": "4.1.0",
"DATA_ROOT": "/workspace/data/Task09_Spleen_nii",
"DATASET_JSON": "config/dataset_0.json",
"PROCESSING_TASK": "segmentation",
"TRAIN_DATALIST_KEY": "training",
"VAL_DATALIST_KEY": "validation",
"INFER_DATALIST_KEY": "test",
"MMAR_EVAL_OUTPUT_PATH": "eval",
"MMAR_CKPT_DIR": "models",
"MMAR_CKPT": "models/model.pt",
"MMAR_TORCHSCRIPT": "models/model.ts",
"INPUT_CHANNELS": 1,
"OUTPUT_CHANNELS": 2
})
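# these values fill the "{...}" placeholders used in the components above ({DATASET_JSON},
# {DATA_ROOT}, {MMAR_CKPT_DIR}, {INPUT_CHANNELS}, {OUTPUT_CHANNELS}, ...) and end up in
# config/environment.json, the file the command scripts pass via -e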
mmar.set_datalist(
datalist={
"training": [
{"image": "imagesTr/spleen_29.nii", "label": "labelsTr/spleen_29.nii"},
{"image": "imagesTr/spleen_46.nii", "label": "labelsTr/spleen_46.nii"},
{"image": "imagesTr/spleen_25.nii", "label": "labelsTr/spleen_25.nii"},
{"image": "imagesTr/spleen_13.nii", "label": "labelsTr/spleen_13.nii"},
{"image": "imagesTr/spleen_62.nii", "label": "labelsTr/spleen_62.nii"},
{"image": "imagesTr/spleen_27.nii", "label": "labelsTr/spleen_27.nii"},
{"image": "imagesTr/spleen_44.nii", "label": "labelsTr/spleen_44.nii"},
{"image": "imagesTr/spleen_56.nii", "label": "labelsTr/spleen_56.nii"},
{"image": "imagesTr/spleen_60.nii", "label": "labelsTr/spleen_60.nii"},
{"image": "imagesTr/spleen_2.nii", "label": "labelsTr/spleen_2.nii"},
{"image": "imagesTr/spleen_53.nii", "label": "labelsTr/spleen_53.nii"},
{"image": "imagesTr/spleen_41.nii", "label": "labelsTr/spleen_41.nii"},
{"image": "imagesTr/spleen_22.nii", "label": "labelsTr/spleen_22.nii"},
{"image": "imagesTr/spleen_14.nii", "label": "labelsTr/spleen_14.nii"},
{"image": "imagesTr/spleen_18.nii", "label": "labelsTr/spleen_18.nii"},
{"image": "imagesTr/spleen_20.nii", "label": "labelsTr/spleen_20.nii"},
{"image": "imagesTr/spleen_32.nii", "label": "labelsTr/spleen_32.nii"},
{"image": "imagesTr/spleen_16.nii", "label": "labelsTr/spleen_16.nii"},
{"image": "imagesTr/spleen_12.nii", "label": "labelsTr/spleen_12.nii"},
{"image": "imagesTr/spleen_63.nii", "label": "labelsTr/spleen_63.nii"},
{"image": "imagesTr/spleen_28.nii", "label": "labelsTr/spleen_28.nii"},
{"image": "imagesTr/spleen_24.nii", "label": "labelsTr/spleen_24.nii"},
{"image": "imagesTr/spleen_59.nii", "label": "labelsTr/spleen_59.nii"},
{"image": "imagesTr/spleen_47.nii", "label": "labelsTr/spleen_47.nii"},
{"image": "imagesTr/spleen_8.nii", "label": "labelsTr/spleen_8.nii"},
{"image": "imagesTr/spleen_6.nii", "label": "labelsTr/spleen_6.nii"},
{"image": "imagesTr/spleen_61.nii", "label": "labelsTr/spleen_61.nii"},
{"image": "imagesTr/spleen_10.nii", "label": "labelsTr/spleen_10.nii"},
{"image": "imagesTr/spleen_38.nii", "label": "labelsTr/spleen_38.nii"},
{"image": "imagesTr/spleen_45.nii", "label": "labelsTr/spleen_45.nii"},
{"image": "imagesTr/spleen_26.nii", "label": "labelsTr/spleen_26.nii"},
{"image": "imagesTr/spleen_49.nii", "label": "labelsTr/spleen_49.nii"},
],
"validation": [
{"image": "imagesTr/spleen_19.nii", "label": "labelsTr/spleen_19.nii"},
{"image": "imagesTr/spleen_31.nii", "label": "labelsTr/spleen_31.nii"},
{"image": "imagesTr/spleen_52.nii", "label": "labelsTr/spleen_52.nii"},
{"image": "imagesTr/spleen_40.nii", "label": "labelsTr/spleen_40.nii"},
{"image": "imagesTr/spleen_3.nii", "label": "labelsTr/spleen_3.nii"},
{"image": "imagesTr/spleen_17.nii", "label": "labelsTr/spleen_17.nii"},
{"image": "imagesTr/spleen_21.nii", "label": "labelsTr/spleen_21.nii"},
{"image": "imagesTr/spleen_33.nii", "label": "labelsTr/spleen_33.nii"},
{"image": "imagesTr/spleen_9.nii", "label": "labelsTr/spleen_9.nii"},
],
}
)
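# the datalist is written as config/dataset_0.json (see datalist_name and DATASET_JSON
# above); the "training"/"validation" keys match TRAIN_DATALIST_KEY and VAL_DATALIST_KEY,
# and image/label paths are relative to DATA_ROOT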
mmar.set_commands({
"set_env.sh": "#!/usr/bin/env bash\nexport PYTHONPATH=\"$PYTHONPATH:/opt/nvidia\"\nDIR=\"$( cd\"$( dirname\"${BASH_SOURCE[0]}\")\">/dev/null 2>&1 && pwd )\"\nexport MMAR_ROOT=${DIR}/..",
"train.sh": "#!/usr/bin/env bash\nmy_dir=\"$(dirname\"$0\")\"\n. $my_dir/set_env.sh\necho\"MMAR_ROOT set to $MMAR_ROOT\"\nadditional_options=\"$*\"\nCONFIG_FILE=config/config_train.json\nENVIRONMENT_FILE=config/environment.json\npython3 -u -m medl.apps.train -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE --write_train_stats --set print_conf=True epochs=1260 learning_rate=0.0002 num_interval_per_valid=20 multi_gpu=False cudnn_benchmark=False dont_load_ckpt_model=True ${additional_options}",
"validate_ckpt.sh": "#!/usr/bin/env bash\nmy_dir=\"$(dirname\"$0\")\"\n. $my_dir/set_env.sh\necho\"MMAR_ROOT set to $MMAR_ROOT\"\nadditional_options=\"$*\"\nCONFIG_FILE=config/config_validation.json\nENVIRONMENT_FILE=config/environment.json\npython3 -u -m medl.apps.evaluate -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE --set print_conf=True multi_gpu=False dont_load_ts_model=True dont_load_ckpt_model=False ${additional_options}",
})
# to keep this example simple, we don't add docs, an inference config, etc.
# build MMAR
mmar.build()
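# build() is expected to lay out the archive under spleen_segmentation/: the train and
# validation configs, environment.json and dataset_0.json under config/, the shell
# scripts under commands/, and the logging config under resources/ (the exact layout
# depends on the mmar_creator implementation)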
# test load MMAR
test_mmar = MMAR()
test_mmar.load(root_path="spleen_segmentation")
print(test_mmar.get_train_config().get_config())
print(test_mmar.get_validate_config().get_config())
if __name__ == "__main__":
main()
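Once built, the archive can be exercised inside a Clara Train container by running the generated scripts, e.g. spleen_segmentation/commands/train.sh for training and validate_ckpt.sh for checkpoint evaluation (assuming the builder places the scripts in the standard commands/ folder); both source set_env.sh to resolve MMAR_ROOT.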