Create spleen annotation MMAR with MMAR API

Copy
Copied!
            

from ignite.metrics import Accuracy
from medl.tools.mmar_creator.component import (
    CheckpointSaverComponent,
    Component,
    OptimizerComponent,
)
from medl.tools.mmar_creator.mmar import MMAR
from medl.tools.mmar_creator.utils import (
    ConfVars,
    create_train_config,
    create_validate_config,
)
from monai.data import CacheDataset, DataLoader, Dataset
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
    CheckpointLoader,
    CheckpointSaver,
    LrScheduleHandler,
    MeanDice,
    MetricsSaver,
    StatsHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
)
from monai.inferers import SimpleInferer
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
    Activationsd,
    AddExtremePointsChanneld,
    AsDiscreted,
    CopyItemsd,
    CropForegroundd,
    EnsureChannelFirstd,
    Invertd,
    LoadImaged,
    RandFlipd,
    RandRotate90d,
    RandShiftIntensityd,
    RandZoomd,
    Resized,
    ScaleIntensityRanged,
    ToTensord,
)
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR


def _spleen_entries(case_ids):
    """Return one image/label datalist entry per spleen case id.

    Args:
        case_ids: Iterable of case numbers, in the order the entries
            should appear in the datalist.

    Returns:
        A list of dicts, each with an ``"image"`` path under ``imagesTr/``
        and a ``"label"`` path under ``labelsTr/`` for the same case id.
    """
    return [
        {
            "image": f"imagesTr/spleen_{case_id}.nii.gz",
            "label": f"labelsTr/spleen_{case_id}.nii.gz",
        }
        for case_id in case_ids
    ]


# Case ids for each datalist section; order is preserved in DATALIST.
_TRAINING_CASE_IDS = (
    29, 46, 25, 13, 62, 27, 44, 56,
    60, 2, 53, 41, 22, 14, 18, 20,
    32, 16, 12, 63, 28, 24, 59, 47,
    8, 6, 61, 10, 38, 45, 26, 49,
)
_VALIDATION_CASE_IDS = (19, 31, 52, 40, 3, 17, 21, 33, 9)

# Datalist handed to the MMAR in main(); keys match TRAIN_DATALIST_KEY and
# VAL_DATALIST_KEY in the environment variables defined there.
DATALIST = {
    "training": _spleen_entries(_TRAINING_CASE_IDS),
    "validation": _spleen_entries(_VALIDATION_CASE_IDS),
}


def main():
    """Assemble, build and reload the ``spleen_annotation`` MMAR.

    Declares the environment variables, the training/validation pipeline and
    the standalone evaluation pipeline as ``Component`` objects, turns them
    into a train config and a validate config, builds an MMAR rooted at
    ``spleen_annotation`` from those configs plus ``DATALIST``, then loads the
    MMAR back and prints both configs.
    """
    # env variables
    # NOTE(review): INFER_DATALIST_KEY is "test", but DATALIST only defines
    # "training" and "validation" sections - confirm this key is unused.
    env_vars = ConfVars(
        {
            "CLARA_TRAIN_VERSION": "4.1.0",
            "DATA_ROOT": "/dataset",
            "DATASET_JSON": "config/dataset_0.json",
            "PROCESSING_TASK": "segmentation",
            "TRAIN_DATALIST_KEY": "training",
            "VAL_DATALIST_KEY": "validation",
            "INFER_DATALIST_KEY": "test",
            "MMAR_EVAL_OUTPUT_PATH": "eval",
            "MMAR_CKPT_DIR": "models",
            "MMAR_CKPT": "models/model.pt",
            "MMAR_TORCHSCRIPT": "models/model.ts",
            "INPUT_CHANNELS": 2,
            "OUTPUT_CHANNELS": 2,
        }
    )

    # train configuration variables
    train_conf_vars = ConfVars(
        {
            "epochs": 2000,
            "learning_rate": 2e-4,
            "num_interval_per_valid": 10,
            "multi_gpu": False,
            "amp": True,
            "determinism": {"random_seed": 0},
            "cudnn_benchmark": False,
            "dont_load_ckpt_model": True,
        }
    )

    # define train components
    loss = Component(name=DiceLoss, args={"to_onehot_y": True, "softmax": True})
    optimizer = OptimizerComponent(
        name=Adam,
        args={
            "lr": train_conf_vars.get("learning_rate"),
        },
    )
    train_model = Component(
        name=UNet,
        args={
            "spatial_dims": 3,
            "in_channels": env_vars.get("INPUT_CHANNELS"),
            "out_channels": env_vars.get("OUTPUT_CHANNELS"),
            "channels": [16, 32, 64, 128, 256],
            "strides": [2, 2, 2, 2],
            "num_res_units": 2,
            "norm": "batch",
        },
    )
    # NOTE(review): step_size (5000) is larger than "epochs" (2000); if the
    # scheduler is stepped once per epoch the learning rate never decays -
    # confirm this is intended.
    lr_scheduler = Component(
        name=StepLR, args={"optimizer": optimizer, "step_size": 5000, "gamma": 0.1}
    )

    # define pre_transforms
    load_image = Component(name=LoadImaged, args={"keys": ["image", "label"]})
    ensure_channelfirst = Component(
        name=EnsureChannelFirstd, args={"keys": ["image", "label"]}
    )
    scale_intensity_range = Component(
        name=ScaleIntensityRanged,
        args={
            "keys": "image",
            "a_min": -57,
            "a_max": 164,
            "b_min": 0.0,
            "b_max": 1.0,
            "clip": True,
        },
    )
    crop_foreground = Component(
        name=CropForegroundd,
        args={"keys": ["image", "label"], "source_key": "label", "margin": 20},
    )
    rand_shift_intensity = Component(
        name=RandShiftIntensityd, args={"keys": "image", "offsets": 0.2, "prob": 0.5}
    )
    rand_flip = Component(
        name=RandFlipd,
        args={"keys": ["image", "label"], "prob": 0.5, "spatial_axis": 0},
    )
    rand_rotate = Component(
        name=RandRotate90d,
        args={"keys": ["image", "label"], "spatial_axes": [1, 2], "prob": 0.5},
    )
    rand_zoom = Component(
        name=RandZoomd,
        args={
            "keys": ["image", "label"],
            "min_zoom": 0.8,
            "max_zoom": 1.2,
            "mode": ["area", "nearest"],
            "prob": 1.0,
        },
    )
    resize = Component(
        name=Resized,
        args={
            "keys": ["image", "label"],
            "spatial_size": [128, 128, 128],
            "mode": ["area", "nearest"],
        },
    )
    # The training and validation variants differ only in "pert" (3 vs 0).
    add_extreme_points_channel = Component(
        name=AddExtremePointsChanneld,
        args={"keys": "image", "label_key": "label", "sigma": 3, "pert": 3},
    )
    add_extreme_points_channel_val = Component(
        name=AddExtremePointsChanneld,
        args={"keys": "image", "label_key": "label", "sigma": 3, "pert": 0},
    )
    to_tensor = Component(name=ToTensord, args={"keys": ["image", "label"]})
    train_pre_transforms = [
        load_image,
        ensure_channelfirst,
        scale_intensity_range,
        crop_foreground,
        rand_shift_intensity,
        rand_flip,
        rand_rotate,
        rand_zoom,
        resize,
        add_extreme_points_channel,
        to_tensor,
    ]
    # Same chain as training minus the random augmentations.
    val_pre_transforms = [
        load_image,
        ensure_channelfirst,
        scale_intensity_range,
        crop_foreground,
        resize,
        add_extreme_points_channel_val,
        to_tensor,
    ]

    # define train dataset and dataloader
    # cache_num equals the number of "training" entries in DATALIST (32).
    train_dataset = Component(
        name=CacheDataset,
        vars={
            "data_list_file_path": env_vars.get("DATASET_JSON"),
            "data_file_base_dir": env_vars.get("DATA_ROOT"),
            "data_list_key": env_vars.get("TRAIN_DATALIST_KEY"),
        },
        args={
            "transform": train_pre_transforms,
            "cache_num": 32,
            "cache_rate": 1.0,
            "num_workers": 2,
        },
    )
    train_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": train_dataset,
            "batch_size": 12,
            "shuffle": True,
            "num_workers": 2,
        },
    )

    # define train inferer
    train_inferer = Component(name=SimpleInferer)

    # define post transforms
    activation = Component(name=Activationsd, args={"keys": "pred", "softmax": True})
    as_discrete = Component(
        name=AsDiscreted,
        args={"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2},
    )
    post_transforms = [activation, as_discrete]

    # define train metrics
    train_acc = Component(
        name=Accuracy,
        vars={"log_label": "train_acc"},
        args={"output_transform": "#monai.handlers.from_engine(['pred', 'label'])"},
    )

    # define validation dataset and dataloader
    # cache_num equals the number of "validation" entries in DATALIST (9).
    val_dataset = Component(
        name=CacheDataset,
        vars={
            "data_list_file_path": env_vars.get("DATASET_JSON"),
            "data_file_base_dir": env_vars.get("DATA_ROOT"),
            "data_list_key": env_vars.get("VAL_DATALIST_KEY"),
        },
        args={
            "transform": val_pre_transforms,
            "cache_num": 9,
            "cache_rate": 1.0,
            "num_workers": 2,
        },
    )
    val_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": val_dataset,
            "batch_size": 12,
            "shuffle": False,
            "num_workers": 2,
        },
    )

    # define validation inferer
    val_inferer = Component(name=SimpleInferer)

    # define validation handlers
    val_stats_handler = Component(
        name=StatsHandler, vars={"rank": 0}, args={"output_transform": "lambda x: None"}
    )
    val_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={
            "log_dir": env_vars.get("MMAR_CKPT_DIR"),
            "output_transform": "lambda x: None",
        },
    )
    val_checkpoint_saver = CheckpointSaverComponent(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": env_vars.get("MMAR_CKPT_DIR"),
            "save_dict": {"model": train_model},
            "save_final": True,
            "save_key_metric": True,
        },
    )
    val_handlers = [
        val_stats_handler,
        val_tb_stats_handler,
        val_checkpoint_saver,
    ]

    # define validation metrics
    val_mean_dice = Component(
        name=MeanDice,
        vars={"log_label": "val_mean_dice"},
        args={
            "include_background": False,
            "output_transform": "#monai.handlers.from_engine(['pred', 'label'])",
        },
    )
    val_acc = Component(
        name=Accuracy,
        vars={"log_label": "val_acc"},
        args={"output_transform": "#monai.handlers.from_engine(['pred', 'label'])"},
    )
    val_additional_metrics = [val_acc]

    # define train validator
    train_validator = Component(
        name=SupervisedEvaluator,
        args={
            "device": "cuda",
            "val_data_loader": val_dataloader,
            "network": train_model,
            "inferer": val_inferer,
            "postprocessing": post_transforms,
            "key_val_metric": val_mean_dice,
            "additional_metrics": val_additional_metrics,
            "val_handlers": val_handlers,
            "amp": train_conf_vars.get("amp"),
        },
    )

    # define train handlers
    train_checkpoint_loader = Component(
        name=CheckpointLoader,
        vars={"disabled": train_conf_vars.get("dont_load_ckpt_model")},
        args={
            "load_path": env_vars.get("MMAR_CKPT"),
            "load_dict": {"model": train_model},
        },
    )
    lr_scheduler_handler = Component(
        name=LrScheduleHandler, args={"lr_scheduler": lr_scheduler, "print_lr": True}
    )
    validation_handler = Component(
        name=ValidationHandler,
        args={
            "validator": train_validator,
            "epoch_level": True,
            "interval": train_conf_vars.get("num_interval_per_valid"),
        },
    )
    train_checkpoint_saver = CheckpointSaverComponent(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": env_vars.get("MMAR_CKPT_DIR"),
            "save_dict": {
                "model": train_model,
                "optimizer": optimizer,
                "lr_scheduler": lr_scheduler,
            },
            "save_final": True,
            "save_interval": 400,
        },
    )
    train_stats_handler = Component(
        name=StatsHandler,
        vars={"rank": 0},
        args={
            "tag_name": "train_loss",
            "output_transform": "#monai.handlers.from_engine(['loss'], first=True)",
        },
    )
    train_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={
            "log_dir": env_vars.get("MMAR_CKPT_DIR"),
            "tag_name": "train_loss",
            "output_transform": "#monai.handlers.from_engine(['loss'], first=True)",
        },
    )
    train_handlers = [
        train_checkpoint_loader,
        lr_scheduler_handler,
        validation_handler,
        train_checkpoint_saver,
        train_stats_handler,
        train_tb_stats_handler,
    ]

    # define trainer
    trainer = Component(
        name=SupervisedTrainer,
        args={
            "max_epochs": train_conf_vars.get("epochs"),
            "device": "cuda",
            "train_data_loader": train_dataloader,
            "network": train_model,
            "loss_function": loss,
            "optimizer": optimizer,
            "inferer": train_inferer,
            "postprocessing": post_transforms,
            "key_train_metric": train_acc,
            "train_handlers": train_handlers,
            "amp": train_conf_vars.get("amp"),
        },
    )
    train_section = {
        "loss": loss,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "model": train_model,
        "pre_transforms": train_pre_transforms,
        "dataset": train_dataset,
        "dataloader": train_dataloader,
        "inferer": train_inferer,
        "handlers": train_handlers,
        "post_transforms": post_transforms,
        "key_metric": train_acc,
        "trainer": trainer,
    }
    val_section = {
        "pre_transforms": val_pre_transforms,
        "dataset": val_dataset,
        "dataloader": val_dataloader,
        "inferer": val_inferer,
        "handlers": val_handlers,
        "post_transforms": post_transforms,
        "key_metric": val_mean_dice,
        "additional_metrics": val_additional_metrics,
        "evaluator": train_validator,
    }
    train_conf = create_train_config(
        train_section=train_section, val_section=val_section, conf_vars=train_conf_vars
    )

    # evaluation config
    eval_conf_vars = ConfVars(
        {
            "multi_gpu": False,
            "amp": True,
            "dont_load_ts_model": False,
            "dont_load_ckpt_model": True,
        }
    )

    # define evaluation model
    # Two alternative model sources, each switched by its own "disabled" flag:
    # the TorchScript file is enabled and the checkpoint is disabled here.
    eval_model = [
        Component(
            vars={
                "ts_path": env_vars.get("MMAR_TORCHSCRIPT"),
                "disabled": eval_conf_vars.get("dont_load_ts_model"),
            }
        ),
        Component(
            vars={
                "ckpt_path": env_vars.get("MMAR_CKPT"),
                "disabled": eval_conf_vars.get("dont_load_ckpt_model"),
            }
        ),
    ]

    # define eval transforms
    # The eval chain crops/resizes a copy of the label ("label_foreground")
    # and leaves the "label" key itself out of those two transforms.
    copy_item_eval = Component(
        name=CopyItemsd, args={"keys": "label", "times": 1, "names": "label_foreground"}
    )
    crop_foreground_eval = Component(
        name=CropForegroundd,
        args={
            "keys": ["image", "label_foreground"],
            "source_key": "label",
            "margin": 20,
        },
    )
    resize_eval = Component(
        name=Resized,
        args={
            "keys": ["image", "label_foreground"],
            "spatial_size": [128, 128, 128],
            "mode": ["area", "nearest"],
        },
    )
    add_extreme_points_channel_eval = Component(
        name=AddExtremePointsChanneld,
        args={"keys": "image", "label_key": "label_foreground", "sigma": 3, "pert": 0},
    )
    eval_pre_transforms = [
        load_image,
        ensure_channelfirst,
        copy_item_eval,
        scale_intensity_range,
        crop_foreground_eval,
        resize_eval,
        add_extreme_points_channel_eval,
        to_tensor,
    ]
    invert = Component(
        name=Invertd,
        args={
            "keys": "pred",
            "transform": eval_pre_transforms,
            "orig_keys": "image",
            "meta_keys": "pred_meta_dict",
            "nearest_interp": False,
            "to_tensor": True,
            "device": "cuda",
        },
    )
    eval_post_transforms = [activation, invert, as_discrete]

    # define evaluation dataset and dataloader
    eval_dataset = Component(
        name=Dataset,
        vars={
            "data_list_file_path": env_vars.get("DATASET_JSON"),
            "data_file_base_dir": env_vars.get("DATA_ROOT"),
            "data_list_key": env_vars.get("VAL_DATALIST_KEY"),
        },
        args={"transform": eval_pre_transforms},
    )
    eval_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": eval_dataset,
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
        },
    )

    # define evaluation handlers
    # NOTE(review): load_dict maps "model" to eval_model, which is a list of
    # two Components rather than a single one - confirm the MMAR creator
    # accepts this.
    eval_checkpoint_loader = Component(
        name=CheckpointLoader,
        vars={"disabled": eval_conf_vars.get("dont_load_ckpt_model")},
        args={
            "load_path": env_vars.get("MMAR_CKPT"),
            "load_dict": {"model": eval_model},
        },
    )
    eval_metrics_saver = Component(
        name=MetricsSaver,
        args={
            "save_dir": env_vars.get("MMAR_EVAL_OUTPUT_PATH"),
            "metrics": ["val_mean_dice", "val_acc"],
            "metric_details": ["val_mean_dice"],
            "batch_transform": "#monai.handlers.from_engine(['image_meta_dict'])",
            "summary_ops": "*",
            "save_rank": 0,
        },
    )
    # NOTE(review): this metric uses include_background=True while the
    # in-training val_mean_dice uses False, yet both are logged as
    # "val_mean_dice" - confirm the difference is intended.
    eval_mean_dice = Component(
        name=MeanDice,
        vars={"log_label": "val_mean_dice"},
        args={
            "include_background": True,
            "output_transform": "#monai.handlers.from_engine(['pred', 'label'])",
        },
    )
    eval_handlers = [
        eval_checkpoint_loader,
        val_stats_handler,
        eval_metrics_saver,
    ]

    # define evaluator
    evaluator = Component(
        name=SupervisedEvaluator,
        args={
            "device": "cuda",
            "val_data_loader": eval_dataloader,
            "network": eval_model,
            "inferer": val_inferer,
            "postprocessing": eval_post_transforms,
            "key_val_metric": eval_mean_dice,
            "additional_metrics": val_additional_metrics,
            "val_handlers": eval_handlers,
            "amp": eval_conf_vars.get("amp"),
        },
    )
    eval_section = {
        "model": eval_model,
        "pre_transforms": eval_pre_transforms,
        "dataset": eval_dataset,
        "dataloader": eval_dataloader,
        "inferer": val_inferer,
        "handlers": eval_handlers,
        "post_transforms": eval_post_transforms,
        "key_metric": eval_mean_dice,
        "additional_metrics": val_additional_metrics,
        "evaluator": evaluator,
    }
    eval_conf = create_validate_config(eval_section, conf_vars=eval_conf_vars)

    # create MMAR
    mmar = MMAR(root_path="spleen_annotation", datalist_name="dataset_0.json")
    mmar.set_train_config(train_conf)
    mmar.set_validate_config(eval_conf)
    mmar.set_commands_and_resources_from_template_dir(template_dir="./template")
    mmar.set_environment(env=env_vars.var_dict)
    mmar.set_datalist(datalist=DATALIST)

    # build MMAR
    mmar.build()

    # test load MMAR
    test_mmar = MMAR()
    test_mmar.load(root_path="spleen_annotation")
    print(test_mmar.get_train_config().get_config())
    print(test_mmar.get_validate_config().get_config())


if __name__ == "__main__":
    main()

© Copyright 2021, NVIDIA. Last updated on Feb 2, 2023.