"""Create the spleen annotation MMAR with the MMAR API."""
from ignite.metrics import Accuracy
from medl.tools.mmar_creator.component import (
CheckpointSaverComponent,
Component,
OptimizerComponent,
)
from medl.tools.mmar_creator.mmar import MMAR
from medl.tools.mmar_creator.utils import (
ConfVars,
create_train_config,
create_validate_config,
)
from monai.data import CacheDataset, DataLoader, Dataset
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
CheckpointLoader,
CheckpointSaver,
LrScheduleHandler,
MeanDice,
MetricsSaver,
StatsHandler,
TensorBoardStatsHandler,
ValidationHandler,
)
from monai.inferers import SimpleInferer
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
Activationsd,
AddExtremePointsChanneld,
AsDiscreted,
CopyItemsd,
CropForegroundd,
EnsureChannelFirstd,
Invertd,
LoadImaged,
RandFlipd,
RandRotate90d,
RandShiftIntensityd,
RandZoomd,
Resized,
ScaleIntensityRanged,
ToTensord,
)
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
# Case numbers for each split, listed in the order they appear in the datalist.
_TRAINING_CASE_IDS = (
    29, 46, 25, 13, 62, 27, 44, 56, 60, 2,
    53, 41, 22, 14, 18, 20, 32, 16, 12, 63,
    28, 24, 59, 47, 8, 6, 61, 10, 38, 45,
    26, 49,
)
_VALIDATION_CASE_IDS = (19, 31, 52, 40, 3, 17, 21, 33, 9)


def _spleen_case(case_id):
    """Return the datalist entry (relative image and label paths) for one spleen case."""
    file_name = f"spleen_{case_id}.nii.gz"
    return {"image": f"imagesTr/{file_name}", "label": f"labelsTr/{file_name}"}


# Train/validation split handed to ``mmar.set_datalist`` when the MMAR is built.
# Paths are relative; presumably resolved against DATA_ROOT by the consumer — confirm.
DATALIST = {
    "training": [_spleen_case(case_id) for case_id in _TRAINING_CASE_IDS],
    "validation": [_spleen_case(case_id) for case_id in _VALIDATION_CASE_IDS],
}


# Expression strings passed through to the generated configs as
# output/batch transforms; hoisted because several components repeat them.
_FROM_ENGINE_PRED_LABEL = "#monai.handlers.from_engine(['pred', 'label'])"
_FROM_ENGINE_LOSS = "#monai.handlers.from_engine(['loss'], first=True)"

# Directory the MMAR is written to and then re-loaded from.
_MMAR_ROOT = "spleen_annotation"


def _create_env_vars():
    """Return the MMAR environment variables used for paths and channel counts."""
    return ConfVars(
        {
            "CLARA_TRAIN_VERSION": "4.1.0",
            "DATA_ROOT": "/dataset",
            "DATASET_JSON": "config/dataset_0.json",
            "PROCESSING_TASK": "segmentation",
            "TRAIN_DATALIST_KEY": "training",
            "VAL_DATALIST_KEY": "validation",
            "INFER_DATALIST_KEY": "test",
            "MMAR_EVAL_OUTPUT_PATH": "eval",
            "MMAR_CKPT_DIR": "models",
            "MMAR_CKPT": "models/model.pt",
            "MMAR_TORCHSCRIPT": "models/model.ts",
            # image channel + the extreme-points channel added by
            # AddExtremePointsChanneld in the pre-transforms
            "INPUT_CHANNELS": 2,
            "OUTPUT_CHANNELS": 2,
        }
    )


def _create_train_config(env_vars):
    """Build the training configuration (trainer plus in-training validation).

    Args:
        env_vars: ConfVars produced by ``_create_env_vars``.

    Returns:
        ``(train_conf, shared)``: the result of ``create_train_config`` and a
        dict of component instances that the evaluation config reuses (the
        very same objects, not copies).
    """
    # train configuration variables
    train_conf_vars = ConfVars(
        {
            "epochs": 2000,
            "learning_rate": 2e-4,
            "num_interval_per_valid": 10,
            "multi_gpu": False,
            "amp": True,
            "determinism": {"random_seed": 0},
            "cudnn_benchmark": False,
            "dont_load_ckpt_model": True,
        }
    )

    # define train components
    loss = Component(name=DiceLoss, args={"to_onehot_y": True, "softmax": True})
    optimizer = OptimizerComponent(
        name=Adam,
        args={
            "lr": train_conf_vars.get("learning_rate"),
        },
    )
    train_model = Component(
        name=UNet,
        args={
            "spatial_dims": 3,
            "in_channels": env_vars.get("INPUT_CHANNELS"),
            "out_channels": env_vars.get("OUTPUT_CHANNELS"),
            "channels": [16, 32, 64, 128, 256],
            "strides": [2, 2, 2, 2],
            "num_res_units": 2,
            "norm": "batch",
        },
    )
    # NOTE(review): training runs for 2000 epochs while step_size is 5000; if
    # LrScheduleHandler steps the scheduler once per epoch (no epoch_level arg
    # is passed below), the LR never decays — confirm the intended granularity.
    lr_scheduler = Component(
        name=StepLR, args={"optimizer": optimizer, "step_size": 5000, "gamma": 0.1}
    )

    # define pre_transforms
    load_image = Component(name=LoadImaged, args={"keys": ["image", "label"]})
    ensure_channelfirst = Component(
        name=EnsureChannelFirstd, args={"keys": ["image", "label"]}
    )
    scale_intensity_range = Component(
        name=ScaleIntensityRanged,
        args={
            "keys": "image",
            "a_min": -57,
            "a_max": 164,
            "b_min": 0.0,
            "b_max": 1.0,
            "clip": True,
        },
    )
    crop_foreground = Component(
        name=CropForegroundd,
        args={"keys": ["image", "label"], "source_key": "label", "margin": 20},
    )
    rand_shift_intensity = Component(
        name=RandShiftIntensityd, args={"keys": "image", "offsets": 0.2, "prob": 0.5}
    )
    rand_flip = Component(
        name=RandFlipd,
        args={"keys": ["image", "label"], "prob": 0.5, "spatial_axis": 0},
    )
    rand_rotate = Component(
        name=RandRotate90d,
        args={"keys": ["image", "label"], "spatial_axes": [1, 2], "prob": 0.5},
    )
    rand_zoom = Component(
        name=RandZoomd,
        args={
            "keys": ["image", "label"],
            "min_zoom": 0.8,
            "max_zoom": 1.2,
            "mode": ["area", "nearest"],
            "prob": 1.0,
        },
    )
    resize = Component(
        name=Resized,
        args={
            "keys": ["image", "label"],
            "spatial_size": [128, 128, 128],
            "mode": ["area", "nearest"],
        },
    )
    # Training perturbs the extreme points ("pert": 3); validation uses exact
    # points ("pert": 0).
    add_extreme_points_channel = Component(
        name=AddExtremePointsChanneld,
        args={"keys": "image", "label_key": "label", "sigma": 3, "pert": 3},
    )
    add_extreme_points_channel_val = Component(
        name=AddExtremePointsChanneld,
        args={"keys": "image", "label_key": "label", "sigma": 3, "pert": 0},
    )
    to_tensor = Component(name=ToTensord, args={"keys": ["image", "label"]})
    train_pre_transforms = [
        load_image,
        ensure_channelfirst,
        scale_intensity_range,
        crop_foreground,
        rand_shift_intensity,
        rand_flip,
        rand_rotate,
        rand_zoom,
        resize,
        add_extreme_points_channel,
        to_tensor,
    ]
    val_pre_transforms = [
        load_image,
        ensure_channelfirst,
        scale_intensity_range,
        crop_foreground,
        resize,
        add_extreme_points_channel_val,
        to_tensor,
    ]

    # define train dataset and dataloader
    train_dataset = Component(
        name=CacheDataset,
        vars={
            "data_list_file_path": env_vars.get("DATASET_JSON"),
            "data_file_base_dir": env_vars.get("DATA_ROOT"),
            "data_list_key": env_vars.get("TRAIN_DATALIST_KEY"),
        },
        args={
            "transform": train_pre_transforms,
            "cache_num": 32,
            "cache_rate": 1.0,
            "num_workers": 2,
        },
    )
    train_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": train_dataset,
            "batch_size": 12,
            "shuffle": True,
            "num_workers": 2,
        },
    )

    # define train inferer
    train_inferer = Component(name=SimpleInferer)

    # define post transforms (shared by training and in-training validation)
    activation = Component(name=Activationsd, args={"keys": "pred", "softmax": True})
    as_discrete = Component(
        name=AsDiscreted,
        args={"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2},
    )
    post_transforms = [activation, as_discrete]

    # define train metrics
    train_acc = Component(
        name=Accuracy,
        vars={"log_label": "train_acc"},
        args={"output_transform": _FROM_ENGINE_PRED_LABEL},
    )

    # define validation dataset and dataloader
    val_dataset = Component(
        name=CacheDataset,
        vars={
            "data_list_file_path": env_vars.get("DATASET_JSON"),
            "data_file_base_dir": env_vars.get("DATA_ROOT"),
            "data_list_key": env_vars.get("VAL_DATALIST_KEY"),
        },
        args={
            "transform": val_pre_transforms,
            "cache_num": 9,
            "cache_rate": 1.0,
            "num_workers": 2,
        },
    )
    val_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": val_dataset,
            "batch_size": 12,
            "shuffle": False,
            "num_workers": 2,
        },
    )

    # define validation inferer
    val_inferer = Component(name=SimpleInferer)

    # define validation handlers
    val_stats_handler = Component(
        name=StatsHandler, vars={"rank": 0}, args={"output_transform": "lambda x: None"}
    )
    val_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={
            "log_dir": env_vars.get("MMAR_CKPT_DIR"),
            "output_transform": "lambda x: None",
        },
    )
    val_checkpoint_saver = CheckpointSaverComponent(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": env_vars.get("MMAR_CKPT_DIR"),
            "save_dict": {"model": train_model},
            "save_final": True,
            "save_key_metric": True,
        },
    )
    val_handlers = [
        val_stats_handler,
        val_tb_stats_handler,
        val_checkpoint_saver,
    ]

    # define validation metrics
    val_mean_dice = Component(
        name=MeanDice,
        vars={"log_label": "val_mean_dice"},
        args={
            "include_background": False,
            "output_transform": _FROM_ENGINE_PRED_LABEL,
        },
    )
    val_acc = Component(
        name=Accuracy,
        vars={"log_label": "val_acc"},
        args={"output_transform": _FROM_ENGINE_PRED_LABEL},
    )
    val_additional_metrics = [val_acc]

    # define train validator
    train_validator = Component(
        name=SupervisedEvaluator,
        args={
            "device": "cuda",
            "val_data_loader": val_dataloader,
            "network": train_model,
            "inferer": val_inferer,
            "postprocessing": post_transforms,
            "key_val_metric": val_mean_dice,
            "additional_metrics": val_additional_metrics,
            "val_handlers": val_handlers,
            "amp": train_conf_vars.get("amp"),
        },
    )

    # define train handlers
    # Disabled when "dont_load_ckpt_model" is True (its value above), so
    # training starts from scratch by default.
    train_checkpoint_loader = Component(
        name=CheckpointLoader,
        vars={"disabled": train_conf_vars.get("dont_load_ckpt_model")},
        args={
            "load_path": env_vars.get("MMAR_CKPT"),
            "load_dict": {"model": train_model},
        },
    )
    lr_scheduler_handler = Component(
        name=LrScheduleHandler, args={"lr_scheduler": lr_scheduler, "print_lr": True}
    )
    validation_handler = Component(
        name=ValidationHandler,
        args={
            "validator": train_validator,
            "epoch_level": True,
            "interval": train_conf_vars.get("num_interval_per_valid"),
        },
    )
    train_checkpoint_saver = CheckpointSaverComponent(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": env_vars.get("MMAR_CKPT_DIR"),
            "save_dict": {
                "model": train_model,
                "optimizer": optimizer,
                "lr_scheduler": lr_scheduler,
            },
            "save_final": True,
            "save_interval": 400,
        },
    )
    train_stats_handler = Component(
        name=StatsHandler,
        vars={"rank": 0},
        args={
            "tag_name": "train_loss",
            "output_transform": _FROM_ENGINE_LOSS,
        },
    )
    train_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={
            "log_dir": env_vars.get("MMAR_CKPT_DIR"),
            "tag_name": "train_loss",
            "output_transform": _FROM_ENGINE_LOSS,
        },
    )
    train_handlers = [
        train_checkpoint_loader,
        lr_scheduler_handler,
        validation_handler,
        train_checkpoint_saver,
        train_stats_handler,
        train_tb_stats_handler,
    ]

    # define trainer
    trainer = Component(
        name=SupervisedTrainer,
        args={
            "max_epochs": train_conf_vars.get("epochs"),
            "device": "cuda",
            "train_data_loader": train_dataloader,
            "network": train_model,
            "loss_function": loss,
            "optimizer": optimizer,
            "inferer": train_inferer,
            "postprocessing": post_transforms,
            "key_train_metric": train_acc,
            "train_handlers": train_handlers,
            "amp": train_conf_vars.get("amp"),
        },
    )

    train_section = {
        "loss": loss,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "model": train_model,
        "pre_transforms": train_pre_transforms,
        "dataset": train_dataset,
        "dataloader": train_dataloader,
        "inferer": train_inferer,
        "handlers": train_handlers,
        "post_transforms": post_transforms,
        "key_metric": train_acc,
        "trainer": trainer,
    }
    val_section = {
        "pre_transforms": val_pre_transforms,
        "dataset": val_dataset,
        "dataloader": val_dataloader,
        "inferer": val_inferer,
        "handlers": val_handlers,
        "post_transforms": post_transforms,
        "key_metric": val_mean_dice,
        "additional_metrics": val_additional_metrics,
        "evaluator": train_validator,
    }
    train_conf = create_train_config(
        train_section=train_section, val_section=val_section, conf_vars=train_conf_vars
    )

    # Components the evaluation config reuses as-is (same instances).
    shared = {
        "load_image": load_image,
        "ensure_channelfirst": ensure_channelfirst,
        "scale_intensity_range": scale_intensity_range,
        "to_tensor": to_tensor,
        "activation": activation,
        "as_discrete": as_discrete,
        "val_inferer": val_inferer,
        "val_stats_handler": val_stats_handler,
        "val_additional_metrics": val_additional_metrics,
    }
    return train_conf, shared


def _create_eval_config(env_vars, shared):
    """Build the standalone evaluation (validate) configuration.

    Args:
        env_vars: ConfVars produced by ``_create_env_vars``.
        shared: dict of reusable components returned by ``_create_train_config``.

    Returns:
        The result of ``create_validate_config``.
    """
    load_image = shared["load_image"]
    ensure_channelfirst = shared["ensure_channelfirst"]
    scale_intensity_range = shared["scale_intensity_range"]
    to_tensor = shared["to_tensor"]
    activation = shared["activation"]
    as_discrete = shared["as_discrete"]
    val_inferer = shared["val_inferer"]
    val_stats_handler = shared["val_stats_handler"]
    val_additional_metrics = shared["val_additional_metrics"]

    # evaluation config
    eval_conf_vars = ConfVars(
        {
            "multi_gpu": False,
            "amp": True,
            "dont_load_ts_model": False,
            "dont_load_ckpt_model": True,
        }
    )

    # define evaluation model: a TorchScript source and a checkpoint source,
    # each switched off by its "disabled" flag (with the values above the
    # TorchScript one is enabled and the checkpoint one disabled).
    eval_model = [
        Component(
            vars={
                "ts_path": env_vars.get("MMAR_TORCHSCRIPT"),
                "disabled": eval_conf_vars.get("dont_load_ts_model"),
            }
        ),
        Component(
            vars={
                "ckpt_path": env_vars.get("MMAR_CKPT"),
                "disabled": eval_conf_vars.get("dont_load_ckpt_model"),
            }
        ),
    ]

    # define eval transforms: crop/resize a *copy* of the label so the
    # original "label" keeps its native geometry for metric computation.
    copy_item_eval = Component(
        name=CopyItemsd, args={"keys": "label", "times": 1, "names": "label_foreground"}
    )
    crop_foreground_eval = Component(
        name=CropForegroundd,
        args={
            "keys": ["image", "label_foreground"],
            "source_key": "label",
            "margin": 20,
        },
    )
    resize_eval = Component(
        name=Resized,
        args={
            "keys": ["image", "label_foreground"],
            "spatial_size": [128, 128, 128],
            "mode": ["area", "nearest"],
        },
    )
    add_extreme_points_channel_eval = Component(
        name=AddExtremePointsChanneld,
        args={"keys": "image", "label_key": "label_foreground", "sigma": 3, "pert": 0},
    )
    eval_pre_transforms = [
        load_image,
        ensure_channelfirst,
        copy_item_eval,
        scale_intensity_range,
        crop_foreground_eval,
        resize_eval,
        add_extreme_points_channel_eval,
        to_tensor,
    ]
    # Invert the pre-transforms on the prediction (using the "image" trace)
    # so it matches the un-cropped, un-resized label.
    invert = Component(
        name=Invertd,
        args={
            "keys": "pred",
            "transform": eval_pre_transforms,
            "orig_keys": "image",
            "meta_keys": "pred_meta_dict",
            "nearest_interp": False,
            "to_tensor": True,
            "device": "cuda",
        },
    )
    eval_post_transforms = [activation, invert, as_discrete]

    # define evaluation dataset and dataloader
    eval_dataset = Component(
        name=Dataset,
        vars={
            "data_list_file_path": env_vars.get("DATASET_JSON"),
            "data_file_base_dir": env_vars.get("DATA_ROOT"),
            "data_list_key": env_vars.get("VAL_DATALIST_KEY"),
        },
        args={"transform": eval_pre_transforms},
    )
    eval_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": eval_dataset,
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
        },
    )

    # define evaluation handlers
    eval_checkpoint_loader = Component(
        name=CheckpointLoader,
        vars={"disabled": eval_conf_vars.get("dont_load_ckpt_model")},
        args={
            "load_path": env_vars.get("MMAR_CKPT"),
            "load_dict": {"model": eval_model},
        },
    )
    eval_metrics_saver = Component(
        name=MetricsSaver,
        args={
            "save_dir": env_vars.get("MMAR_EVAL_OUTPUT_PATH"),
            "metrics": ["val_mean_dice", "val_acc"],
            "metric_details": ["val_mean_dice"],
            "batch_transform": "#monai.handlers.from_engine(['image_meta_dict'])",
            "summary_ops": "*",
            "save_rank": 0,
        },
    )
    # NOTE(review): this metric includes the background channel, whereas the
    # in-training validation metric logged under the same "val_mean_dice"
    # label excludes it, so the two values are not directly comparable —
    # confirm this is intended.
    eval_mean_dice = Component(
        name=MeanDice,
        vars={"log_label": "val_mean_dice"},
        args={
            "include_background": True,
            "output_transform": _FROM_ENGINE_PRED_LABEL,
        },
    )
    eval_handlers = [
        eval_checkpoint_loader,
        val_stats_handler,
        eval_metrics_saver,
    ]

    # define evaluator
    evaluator = Component(
        name=SupervisedEvaluator,
        args={
            "device": "cuda",
            "val_data_loader": eval_dataloader,
            "network": eval_model,
            "inferer": val_inferer,
            "postprocessing": eval_post_transforms,
            "key_val_metric": eval_mean_dice,
            "additional_metrics": val_additional_metrics,
            "val_handlers": eval_handlers,
            "amp": eval_conf_vars.get("amp"),
        },
    )
    eval_section = {
        "model": eval_model,
        "pre_transforms": eval_pre_transforms,
        "dataset": eval_dataset,
        "dataloader": eval_dataloader,
        "inferer": val_inferer,
        "handlers": eval_handlers,
        "post_transforms": eval_post_transforms,
        "key_metric": eval_mean_dice,
        "additional_metrics": val_additional_metrics,
        "evaluator": evaluator,
    }
    return create_validate_config(eval_section, conf_vars=eval_conf_vars)


def _build_and_reload_mmar(env_vars, train_conf, eval_conf):
    """Write the MMAR under ``_MMAR_ROOT``, then load it back and print both configs."""
    # create MMAR
    mmar = MMAR(root_path=_MMAR_ROOT, datalist_name="dataset_0.json")
    mmar.set_train_config(train_conf)
    mmar.set_validate_config(eval_conf)
    mmar.set_commands_and_resources_from_template_dir(template_dir="./template")
    mmar.set_environment(env=env_vars.var_dict)
    mmar.set_datalist(datalist=DATALIST)
    # build MMAR
    mmar.build()

    # test load MMAR: a round-trip sanity check of what was just written
    test_mmar = MMAR()
    test_mmar.load(root_path=_MMAR_ROOT)
    print(test_mmar.get_train_config().get_config())
    print(test_mmar.get_validate_config().get_config())


def main():
    """Create the spleen annotation MMAR on disk and print the reloaded configs.

    The steps run in the same order as before: environment variables, then the
    training config, then the evaluation config, then build/reload. Components
    shared by both configs are created once and reused.
    """
    env_vars = _create_env_vars()
    train_conf, shared = _create_train_config(env_vars)
    eval_conf = _create_eval_config(env_vars, shared)
    _build_and_reload_mmar(env_vars, train_conf, eval_conf)


if __name__ == "__main__":
    main()