Create spleen segmentation MMAR with MMAR API

1.0
Copy
Copied!
            

from medl.tools.mmar_creator.train_config import TrainConfig
from medl.tools.mmar_creator.validate_config import ValidateConfig
from medl.tools.mmar_creator.component import Component
from medl.tools.mmar_creator.mmar import MMAR
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from ignite.metrics import Accuracy
from monai.data import CacheDataset, Dataset, DataLoader
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.handlers import (
    LrScheduleHandler,
    ValidationHandler,
    CheckpointSaver,
    StatsHandler,
    TensorBoardStatsHandler,
    MeanDice,
    CheckpointLoader,
    MetricsSaver,
)
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
    Activationsd,
    AsDiscreted,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityRanged,
    CropForegroundd,
    RandCropByPosNegLabeld,
    RandShiftIntensityd,
    ToTensord,
    Invertd,
    SaveImaged,
)

# Output transforms shared by several handlers/metrics. The leading "#" marks the
# value as an expression to be evaluated by the config loader (same convention as
# "#@model.parameters()" below) rather than a plain string.
_FROM_ENGINE_LOSS = "#monai.handlers.from_engine(['loss'], first=True)"
_FROM_ENGINE_PRED_LABEL = "#monai.handlers.from_engine(['pred', 'label'])"

# Case numbers of the imagesTr/labelsTr pairs under DATA_ROOT, in the order they
# are written to the datalist.
_TRAIN_CASE_IDS = (
    29, 46, 25, 13, 62, 27, 44, 56, 60, 2,
    53, 41, 22, 14, 18, 20, 32, 16, 12, 63,
    28, 24, 59, 47, 8, 6, 61, 10, 38, 45,
    26, 49,
)
_VAL_CASE_IDS = (19, 31, 52, 40, 3, 17, 21, 33, 9)

# Contents of the MMAR "log.config" resource (logging fileConfig INI format).
_LOG_CONFIG = "\n".join([
    "[loggers]",
    "keys=root,modelLogger",
    "[handlers]",
    "keys=consoleHandler",
    "[formatters]",
    "keys=fullFormatter",
    "[logger_root]",
    "level=INFO",
    "handlers=consoleHandler",
    "[logger_modelLogger]",
    "level=DEBUG",
    "handlers=consoleHandler",
    "qualname=modelLogger",
    "propagate=0",
    "[handler_consoleHandler]",
    "class=StreamHandler",
    "level=DEBUG",
    "formatter=fullFormatter",
    "args=(sys.stdout,)",
    "[formatter_fullFormatter]",
    "format=%(asctime)s-%(name)s-%(levelname)s-%(message)s",
])

# set_env.sh: resolves the directory containing the script and exports MMAR_ROOT
# as its parent. The spaces between `cd`/`dirname` and their quoted arguments are
# required; without them bash sees a single unknown command word.
_SET_ENV_SH = "\n".join([
    "#!/usr/bin/env bash",
    'export PYTHONPATH="$PYTHONPATH:/opt/nvidia"',
    'DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"',
    "export MMAR_ROOT=${DIR}/..",
])


def _spleen_case(case_id):
    """Return one datalist entry pairing the image and label file for ``case_id``."""
    return {
        "image": f"imagesTr/spleen_{case_id}.nii",
        "label": f"labelsTr/spleen_{case_id}.nii",
    }


def _datalist_vars(datalist_key):
    """Return a fresh dataset ``vars`` dict selecting ``datalist_key`` from the datalist JSON."""
    return {
        "data_list_file_path": "{DATASET_JSON}",
        "data_file_base_dir": "{DATA_ROOT}",
        "data_list_key": datalist_key,
    }


def _launcher_script(config_file, command):
    """Build a bash launcher that sources set_env.sh, then runs ``command``.

    ``command`` may refer to ``$MMAR_ROOT``, ``$CONFIG_FILE``, ``$ENVIRONMENT_FILE``
    and ``${additional_options}`` (the script's own arguments), all defined here.
    """
    return "\n".join([
        "#!/usr/bin/env bash",
        'my_dir="$(dirname "$0")"',
        ". $my_dir/set_env.sh",
        'echo "MMAR_ROOT set to $MMAR_ROOT"',
        'additional_options="$*"',
        f"CONFIG_FILE={config_file}",
        "ENVIRONMENT_FILE=config/environment.json",
        command,
    ])


_TRAIN_SH = _launcher_script(
    "config/config_train.json",
    "python3 -u -m medl.apps.train -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE"
    " --write_train_stats --set print_conf=True epochs=1260 learning_rate=0.0002"
    " num_interval_per_valid=20 multi_gpu=False cudnn_benchmark=False"
    " dont_load_ckpt_model=True ${additional_options}",
)

_VALIDATE_CKPT_SH = _launcher_script(
    "config/config_validation.json",
    "python3 -u -m medl.apps.evaluate -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE"
    " --set print_conf=True multi_gpu=False dont_load_ts_model=True"
    " dont_load_ckpt_model=False ${additional_options}",
)


def main():
    """Build the spleen segmentation MMAR, then reload it and print both configs.

    Components are declared once and shared between the training config and the
    validation config. String values such as ``"{epochs}"`` or ``"{DATA_ROOT}"``
    name config vars / environment entries set further down, and values such as
    ``"@model"`` refer to other components (presumably resolved by the MMAR
    config loader -- NOTE(review): confirm against the medl docs). The MMAR is
    written under ``spleen_segmentation/`` by ``mmar.build()``.
    """
    # define train components
    loss = Component(name=DiceLoss, args={"to_onehot_y": True, "softmax": True})
    optimizer = Component(
        name=Adam,
        args={"params": "#@model.parameters()", "lr": "{learning_rate}"},
    )
    train_model = Component(
        name=UNet,
        args={
            "spatial_dims": 3,
            "in_channels": "{INPUT_CHANNELS}",
            "out_channels": "{OUTPUT_CHANNELS}",
            "channels": [16, 32, 64, 128, 256],
            "strides": [2, 2, 2, 2],
            "num_res_units": 2,
            "norm": "batch",
        },
    )
    lr_scheduler = Component(
        name=StepLR,
        args={"optimizer": "@optimizer", "step_size": 5000, "gamma": 0.1},
    )

    # define pre_transforms
    load_image = Component(name=LoadImaged, args={"keys": ["image", "label"]})
    ensure_channelfirst = Component(name=EnsureChannelFirstd, args={"keys": ["image", "label"]})
    scale_intensity_range = Component(
        name=ScaleIntensityRanged,
        args={"keys": "image", "a_min": -57, "a_max": 164, "b_min": 0.0, "b_max": 1.0, "clip": True},
    )
    crop_foreground = Component(
        name=CropForegroundd,
        args={"keys": ["image", "label"], "source_key": "image"},
    )
    # Evaluation crops only the image; the prediction is inverted back with Invertd below.
    crop_foreground_eval = Component(
        name=CropForegroundd,
        args={"keys": "image", "source_key": "image"},
    )
    rand_crop_by_posneg_label = Component(
        name=RandCropByPosNegLabeld,
        args={
            "keys": ["image", "label"],
            "label_key": "label",
            "spatial_size": [96, 96, 96],
            "pos": 1,
            "neg": 1,
            "num_samples": 4,
            "image_key": "image",
            "image_threshold": 0,
        },
    )
    rand_shift_intensity = Component(
        name=RandShiftIntensityd,
        args={"keys": "image", "offsets": 0.1, "prob": 0.5},
    )
    to_tensor = Component(name=ToTensord, args={"keys": ["image", "label"]})

    # define train dataset and dataloader
    train_dataset = Component(
        name=CacheDataset,
        vars=_datalist_vars("{TRAIN_DATALIST_KEY}"),
        args={"transform": "@pre_transforms", "cache_num": 32, "cache_rate": 1.0, "num_workers": 4},
    )
    train_dataloader = Component(
        name=DataLoader,
        args={"dataset": "@dataset", "batch_size": 2, "shuffle": True, "num_workers": 4},
    )

    # define train inferer
    train_inferer = Component(name=SimpleInferer)

    # define train handlers
    lr_scheduler_handler = Component(
        name=LrScheduleHandler,
        args={"lr_scheduler": "@lr_scheduler", "print_lr": True},
    )
    validation_handler = Component(
        name=ValidationHandler,
        args={"validator": "@evaluator", "epoch_level": True, "interval": "{num_interval_per_valid}"},
    )
    train_checkpoint_saver = Component(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": "{MMAR_CKPT_DIR}",
            "save_dict": {
                "model": "@model",
                "optimizer": "@optimizer",
                "lr_scheduler": "@lr_scheduler",
                "train_conf": "@conf",
            },
            "save_final": True,
            "save_interval": 400,
        },
    )
    train_stats_handler = Component(
        name=StatsHandler,
        vars={"rank": 0},
        args={"tag_name": "train_loss", "output_transform": _FROM_ENGINE_LOSS},
    )
    train_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={"log_dir": "{MMAR_CKPT_DIR}", "tag_name": "train_loss", "output_transform": _FROM_ENGINE_LOSS},
    )

    # define post transforms
    activation = Component(name=Activationsd, args={"keys": "pred", "softmax": True})
    invert = Component(
        name=Invertd,
        args={
            "keys": "pred",
            "transform": "@pre_transforms",
            "orig_keys": "image",
            "meta_keys": "pred_meta_dict",
            "nearest_interp": False,
            "to_tensor": True,
            "device": "cuda",
        },
    )
    # to_onehot=2 matches OUTPUT_CHANNELS in the environment below.
    as_discrete = Component(
        name=AsDiscreted,
        args={"keys": ["pred", "label"], "argmax": [True, False], "to_onehot": 2},
    )
    save_image = Component(
        name=SaveImaged,
        args={
            "keys": "pred",
            "meta_keys": "pred_meta_dict",
            "output_dir": "{MMAR_EVAL_OUTPUT_PATH}",
            "resample": False,
            "squeeze_end_dims": True,
        },
    )

    # define train metrics
    train_acc = Component(
        name=Accuracy,
        vars={"log_label": "train_acc"},
        args={"output_transform": _FROM_ENGINE_PRED_LABEL},
    )

    # define trainer
    trainer = Component(
        name=SupervisedTrainer,
        args={
            "max_epochs": "{epochs}",
            "device": "cuda",
            "train_data_loader": "@dataloader",
            "network": "@model",
            "loss_function": "@loss",
            "optimizer": "@optimizer",
            "inferer": "@inferer",
            "postprocessing": "@post_transforms",
            "key_train_metric": "@key_metric",
            "train_handlers": "@handlers",
            "amp": "{amp}",
        },
    )

    # define validation dataset and dataloader
    val_dataset = Component(
        name=CacheDataset,
        vars=_datalist_vars("{VAL_DATALIST_KEY}"),
        args={"transform": "@pre_transforms", "cache_num": 9, "cache_rate": 1.0, "num_workers": 4},
    )
    val_dataloader = Component(
        name=DataLoader,
        args={"dataset": "@dataset", "batch_size": 1, "shuffle": False, "num_workers": 4},
    )

    # define validation inferer
    val_inferer = Component(
        name=SlidingWindowInferer,
        args={"roi_size": [160, 160, 160], "sw_batch_size": 4, "overlap": 0.5},
    )

    # define validation handlers
    val_stats_handler = Component(
        name=StatsHandler,
        vars={"rank": 0},
        args={"output_transform": "lambda x: None"},
    )
    val_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={"log_dir": "{MMAR_CKPT_DIR}", "output_transform": "lambda x: None"},
    )
    val_checkpoint_saver = Component(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": "{MMAR_CKPT_DIR}",
            "save_dict": {"model": "@model", "train_conf": "@conf"},
            "save_key_metric": True,
        },
    )

    # define validation metrics
    val_mean_dice = Component(
        name=MeanDice,
        vars={"log_label": "val_mean_dice", "is_key_metric": True},
        args={"include_background": False, "output_transform": _FROM_ENGINE_PRED_LABEL},
    )
    val_acc = Component(
        name=Accuracy,
        vars={"log_label": "val_acc"},
        args={"output_transform": _FROM_ENGINE_PRED_LABEL},
    )

    # define evaluator
    evaluator = Component(
        name=SupervisedEvaluator,
        args={
            "device": "cuda",
            "val_data_loader": "@dataloader",
            "network": "@model",
            "inferer": "@inferer",
            "postprocessing": "@post_transforms",
            "key_val_metric": "@key_metric",
            "additional_metrics": "@additional_metrics",
            "val_handlers": "@handlers",
            "amp": "{amp}",
        },
    )

    # define evaluation dataset
    eval_dataset = Component(
        name=Dataset,
        vars=_datalist_vars("{VAL_DATALIST_KEY}"),
        args={"transform": "@pre_transforms"},
    )

    # define evaluation handlers
    eval_checkpoint_loader = Component(
        name=CheckpointLoader,
        vars={"disabled": "{dont_load_ckpt_model}"},
        args={"load_path": "{MMAR_CKPT}", "load_dict": {"model": "@model"}},
    )
    eval_metrics_saver = Component(
        name=MetricsSaver,
        args={
            "save_dir": "{MMAR_EVAL_OUTPUT_PATH}",
            "metrics": ["val_mean_dice", "val_acc"],
            "metric_details": ["val_mean_dice"],
            "batch_transform": "#monai.handlers.from_engine(['image_meta_dict'])",
            "summary_ops": "*",
            "save_rank": 0,
        },
    )

    # create train config based on above components
    train_conf = TrainConfig()
    # set variables in config
    train_conf.add_vars(
        epochs=1260,
        num_interval_per_valid=20,
        learning_rate=2e-4,
        multi_gpu=False,
        amp=True,
        cudnn_benchmark=False,
    )
    train_conf.add_loss(loss)
    train_conf.add_optimizer(optimizer)
    train_conf.add_model(train_model)
    train_conf.add_lr_scheduler(lr_scheduler)
    train_conf.add_train_pre_transform([
        load_image,
        ensure_channelfirst,
        scale_intensity_range,
        crop_foreground,
        rand_crop_by_posneg_label,
        rand_shift_intensity,
        to_tensor,
    ])
    train_conf.add_train_dataset(train_dataset)
    train_conf.add_train_dataloader(train_dataloader)
    train_conf.add_train_inferer(train_inferer)
    train_conf.add_train_handler([
        lr_scheduler_handler,
        validation_handler,
        train_checkpoint_saver,
        train_stats_handler,
        train_tb_stats_handler,
    ])
    train_conf.add_train_post_transform([activation, as_discrete])
    train_conf.add_train_key_metric(train_acc)
    train_conf.add_trainer(trainer)
    train_conf.add_val_pre_transform([
        load_image,
        ensure_channelfirst,
        scale_intensity_range,
        crop_foreground,
        to_tensor,
    ])
    train_conf.add_val_dataset(val_dataset)
    train_conf.add_val_dataloader(val_dataloader)
    train_conf.add_val_inferer(val_inferer)
    train_conf.add_val_handler([
        val_stats_handler,
        val_tb_stats_handler,
        val_checkpoint_saver,
    ])
    train_conf.add_val_post_transform([activation, as_discrete])
    train_conf.add_val_key_metric(val_mean_dice)
    train_conf.add_val_additional_metric([val_acc])
    train_conf.add_evaluator(evaluator)

    # create validation config based on above components
    eval_conf = ValidateConfig()
    # set variables in the config
    eval_conf.add_vars(multi_gpu=False, amp=True, dont_load_ts_model=False, dont_load_ckpt_model=True)
    eval_conf.add_model([
        Component(vars={"ts_path": "{MMAR_TORCHSCRIPT}", "disabled": "{dont_load_ts_model}"}),
        Component(vars={"ckpt_path": "{MMAR_CKPT}", "disabled": "{dont_load_ckpt_model}"}),
    ])
    eval_conf.add_pre_transform([
        load_image,
        ensure_channelfirst,
        scale_intensity_range,
        crop_foreground_eval,
        to_tensor,
    ])
    eval_conf.add_dataset(eval_dataset)
    eval_conf.add_dataloader(val_dataloader)
    eval_conf.add_inferer(val_inferer)
    eval_conf.add_handler([
        eval_checkpoint_loader,
        val_stats_handler,
        eval_metrics_saver,
    ])
    eval_conf.add_post_transform([activation, invert, as_discrete, save_image])
    eval_conf.add_key_metric(val_mean_dice)
    eval_conf.add_additional_metric([val_acc])
    eval_conf.add_evaluator(evaluator)

    # create MMAR
    mmar = MMAR(root_path="spleen_segmentation", datalist_name="dataset_0.json")
    mmar.set_train_config(train_conf)
    mmar.set_validate_config(eval_conf)
    mmar.set_resources({"log.config": _LOG_CONFIG})
    mmar.set_environment(env={
        "CLARA_TRAIN_VERSION": "4.1.0",
        "DATA_ROOT": "/workspace/data/Task09_Spleen_nii",
        "DATASET_JSON": "config/dataset_0.json",
        "PROCESSING_TASK": "segmentation",
        "TRAIN_DATALIST_KEY": "training",
        "VAL_DATALIST_KEY": "validation",
        "INFER_DATALIST_KEY": "test",
        "MMAR_EVAL_OUTPUT_PATH": "eval",
        "MMAR_CKPT_DIR": "models",
        "MMAR_CKPT": "models/model.pt",
        "MMAR_TORCHSCRIPT": "models/model.ts",
        "INPUT_CHANNELS": 1,
        "OUTPUT_CHANNELS": 2,
    })
    mmar.set_datalist(
        datalist={
            "training": [_spleen_case(case_id) for case_id in _TRAIN_CASE_IDS],
            "validation": [_spleen_case(case_id) for case_id in _VAL_CASE_IDS],
        }
    )
    mmar.set_commands({
        "set_env.sh": _SET_ENV_SH,
        "train.sh": _TRAIN_SH,
        "validate_ckpt.sh": _VALIDATE_CKPT_SH,
    })
    # NOTE: to keep this example simple, we don't add docs, inference configs, etc.

    # build MMAR
    mmar.build()

    # test load MMAR
    test_mmar = MMAR()
    test_mmar.load(root_path="spleen_segmentation")
    print(test_mmar.get_train_config().get_config())
    print(test_mmar.get_validate_config().get_config())


if __name__ == "__main__":
    main()

© Copyright 2021, NVIDIA. Last updated on Feb 2, 2023.