Create a chest X-ray classification MMAR with the MMAR API

Copy
Copied!
            

import json

from ignite.metrics import Accuracy
from medl.tools.mmar_creator.component import Component
from medl.tools.mmar_creator.mmar import MMAR
from medl.tools.mmar_creator.train_config import TrainConfig
from medl.tools.mmar_creator.validate_config import ValidateConfig
from monai.data import CacheDataset, DataLoader, Dataset, PersistentDataset
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (
    CheckpointLoader,
    CheckpointSaver,
    ClassificationSaver,
    LrScheduleHandler,
    MetricsSaver,
    StatsHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
)
from monai.handlers.roc_auc import ROCAUC
from monai.inferers import SimpleInferer
from monai.networks.nets import DenseNet
from monai.transforms import (
    Activationsd,
    AddChanneld,
    AsDiscreted,
    CastToTyped,
    CopyItemsd,
    LoadImaged,
    NormalizeIntensityd,
    RandRotated,
    RandSpatialCropd,
    Resized,
    SaveImaged,
    SplitChanneld,
    ToNumpyd,
    ToTensord,
)
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

# Directory the MMAR is built into and then re-loaded from as a smoke test.
MMAR_ROOT_PATH = "pt_chest_xray_classification"

# Default location of the PLCO datalist JSON; the file must already exist.
DEFAULT_PLCO_DATALIST_PATH = "local_path_to/plco.json"

# Finding names, in channel order. They are used both as the SplitChanneld
# output postfixes and as the per-finding ROCAUC labels, so the two can never
# drift apart.
CLASS_NAMES = (
    "Nodule",
    "Mass",
    "Distortion_pulmonary_architecture",
    "Pleural_based_mass",
    "Granuloma",
    "Fluid_in_pleural_space",
    "Right_hilar_abnormality",
    "Left_hilar_abnormality",
    "Major_atelectasis",
    "Infiltrate",
    "Scarring",
    "Pleural_fibrosis",
    "Bone_soft_tissue_lesion",
    "Cardiac_abnormality",
    "COPD",
)
NUM_CLASSES = len(CLASS_NAMES)  # 15

# Contents of resources/log.config (Python logging fileConfig format).
LOG_CONFIG = "\n".join((
    "[loggers]",
    "keys=root,modelLogger",
    "[handlers]",
    "keys=consoleHandler",
    "[formatters]",
    "keys=fullFormatter",
    "[logger_root]",
    "level=INFO",
    "handlers=consoleHandler",
    "[logger_modelLogger]",
    "level=DEBUG",
    "handlers=consoleHandler",
    "qualname=modelLogger",
    "propagate=0",
    "[handler_consoleHandler]",
    "class=StreamHandler",
    "level=DEBUG",
    "formatter=fullFormatter",
    "args=(sys.stdout,)",
    "[formatter_fullFormatter]",
    "format=%(asctime)s-%(name)s-%(levelname)s-%(message)s",
))

# NOTE: the shell snippets below need a space between a command and its quoted
# argument (`cd "..."`, `dirname "$0"`, `echo "..."`). Without it bash glues
# both into a single word and tries to run a command that does not exist.
SET_ENV_SH = "\n".join((
    "#!/usr/bin/env bash",
    'export PYTHONPATH="$PYTHONPATH:/opt/nvidia"',
    'DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"',
    "export MMAR_ROOT=${DIR}/..",
))


def _launcher_script(config_file, python_command):
    """Return a bash launcher that sources set_env.sh and runs *python_command*.

    Args:
        config_file: value assigned to ``CONFIG_FILE`` inside the script.
        python_command: the final ``python3 ...`` line of the script.
    """
    return "\n".join((
        "#!/usr/bin/env bash",
        'my_dir="$(dirname "$0")"',
        ". $my_dir/set_env.sh",
        'echo "MMAR_ROOT set to $MMAR_ROOT"',
        'additional_options="$*"',
        f"CONFIG_FILE={config_file}",
        "ENVIRONMENT_FILE=config/environment.json",
        python_command,
    ))


TRAIN_SH = _launcher_script(
    "config/config_train.json",
    "python3 -u -m medl.apps.train -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE"
    " --write_train_stats --set print_conf=True multi_gpu=False cudnn_benchmark=False"
    " dont_load_ckpt_model=True ${additional_options}",
)

VALIDATE_CKPT_SH = _launcher_script(
    "config/config_validation.json",
    "python3 -u -m medl.apps.evaluate -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE"
    " --set print_conf=True multi_gpu=False dont_load_ts_model=True"
    " dont_load_ckpt_model=False ${additional_options}",
)


def _per_class_auc_metrics():
    """Return one ROCAUC component per entry of CLASS_NAMES, in order.

    Each metric reads the ``pred_<name>`` / ``label_<name>`` keys, i.e. the
    keys named by the SplitChanneld ``output_postfixes`` configured in main().
    """
    return [
        Component(
            name=ROCAUC,
            vars={"log_label": class_name},
            args={
                "output_transform": (
                    f"#monai.handlers.from_engine(['pred_{class_name}', 'label_{class_name}'])"
                ),
            },
        )
        for class_name in CLASS_NAMES
    ]


def main(plco_datalist_path: str = DEFAULT_PLCO_DATALIST_PATH) -> None:
    """Build the chest X-ray classification MMAR, then reload and print it.

    The function declares every training / validation / evaluation component,
    registers them on a TrainConfig and a ValidateConfig, attaches resources,
    environment, datalist and shell commands to an MMAR rooted at
    MMAR_ROOT_PATH, builds it, and finally loads it back and prints both
    configs as a smoke test.

    NOTE(review): string values such as "@model", "{learning_rate}" and
    "#monai.handlers.from_engine(...)" are passed through verbatim; they look
    like references / variable substitutions / expressions resolved later by
    the MMAR tooling -- confirm against the Clara Train documentation.

    Args:
        plco_datalist_path: path of an existing PLCO datalist JSON file whose
            parsed content is stored in the MMAR as its datalist.

    Raises:
        OSError: if the datalist file cannot be opened.
        json.JSONDecodeError: if the datalist file is not valid JSON.
    """
    # ------------------------------------------------------------------
    # train components
    # ------------------------------------------------------------------
    loss = Component(name=BCEWithLogitsLoss)
    optimizer = Component(
        name=Adam,
        args={
            "params": "#@model.parameters()",
            "lr": "{learning_rate}",
            "weight_decay": 1e-5,
            "amsgrad": False,
        },
    )
    train_model = Component(
        name=DenseNet,
        args={
            "init_features": 64,
            "growth_rate": 32,
            "block_config": [6, 12, 24, 16],
            "spatial_dims": 2,
            "in_channels": "{INPUT_CHANNELS}",
            "out_channels": "{OUTPUT_CHANNELS}",
        },
    )
    lr_scheduler = Component(
        name=StepLR,
        args={"optimizer": "@optimizer", "step_size": 40, "gamma": 0.1},
    )

    # ------------------------------------------------------------------
    # pre-transforms
    # ------------------------------------------------------------------
    load_image = Component(name=LoadImaged, args={"keys": ["image"]})
    add_channel = Component(name=AddChanneld, args={"keys": ["image"]})
    rand_spatial_crop = Component(
        name=RandSpatialCropd,
        args={
            "keys": "image",
            "roi_size": [230, 230],
            "random_center": True,
            "random_size": True,
        },
    )
    resize = Component(
        name=Resized,
        args={"keys": ["image"], "spatial_size": [256, 256]},
    )
    rand_rotate = Component(
        name=RandRotated,
        args={"keys": ["image"], "range_x": 7},
    )
    normalize_intensity = Component(
        name=NormalizeIntensityd,
        args={
            "keys": ["image"],
            "subtrahend": 2876.37,
            "divisor": 883,
            "dtype": "float32",
        },
    )
    to_numpy = Component(name=ToNumpyd, args={"keys": ["label"]})
    cast_to_type = Component(
        name=CastToTyped,
        args={"keys": ["label"], "dtype": "float32"},
    )
    to_tensor = Component(name=ToTensord, args={"keys": ["image", "label"]})

    # Training adds the two random augmentations; validation and evaluation
    # share the same chain without them.
    train_pre_transforms = [
        load_image,
        add_channel,
        rand_spatial_crop,
        resize,
        rand_rotate,
        normalize_intensity,
        to_numpy,
        cast_to_type,
        to_tensor,
    ]
    deterministic_pre_transforms = [
        load_image,
        add_channel,
        resize,
        normalize_intensity,
        to_numpy,
        cast_to_type,
        to_tensor,
    ]

    # ------------------------------------------------------------------
    # train dataset, dataloader and inferer
    # ------------------------------------------------------------------
    train_dataset = Component(
        name=CacheDataset,
        vars={
            "data_list_file_path": "{DATASET_JSON}",
            "data_file_base_dir": "{DATA_ROOT}",
            "data_list_key": "{TRAIN_DATALIST_KEY}",
        },
        args={
            "transform": "@pre_transforms",
            "cache_num": 128,
            "cache_rate": 0.2,
            "num_workers": 5,
        },
    )
    train_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": "@dataset",
            "batch_size": 20,
            "shuffle": True,
            "num_workers": 5,
        },
    )
    train_inferer = Component(name=SimpleInferer)

    # ------------------------------------------------------------------
    # train handlers
    # ------------------------------------------------------------------
    checkpoint_loader = Component(
        name=CheckpointLoader,
        vars={"disabled": "{dont_load_ckpt_model}"},
        args={"load_path": "{MMAR_CKPT}", "load_dict": {"model": "@model"}},
    )
    lr_scheduler_handler = Component(
        name=LrScheduleHandler,
        args={"lr_scheduler": "@lr_scheduler", "print_lr": True},
    )
    validation_handler = Component(
        name=ValidationHandler,
        args={
            "validator": "@evaluator",
            "epoch_level": True,
            "interval": "{num_interval_per_valid}",
        },
    )
    train_checkpoint_saver = Component(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": "{MMAR_CKPT_DIR}",
            "save_dict": {
                "model": "@model",
                "optimizer": "@optimizer",
                "lr_scheduler": "@lr_scheduler",
                "train_conf": "@conf",
            },
            "save_final": True,
            "save_interval": 400,
        },
    )
    train_stats_handler = Component(
        name=StatsHandler,
        vars={"rank": 0},
        args={
            "tag_name": "train_loss",
            "output_transform": "#monai.handlers.from_engine(['loss'], first=True)",
        },
    )
    train_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={
            "log_dir": "{MMAR_CKPT_DIR}",
            "tag_name": "train_loss",
            "output_transform": "#monai.handlers.from_engine(['loss'], first=True)",
        },
    )

    # ------------------------------------------------------------------
    # train post-transforms and metrics
    # ------------------------------------------------------------------
    activation = Component(
        name=Activationsd,
        args={"keys": "pred", "sigmoid": True},
    )
    copy_items = Component(
        name=CopyItemsd,
        args={"times": 1, "keys": "pred", "names": ["binary_preds"]},
    )
    as_discrete = Component(
        name=AsDiscreted,
        args={"keys": ["pred"], "argmax": [True], "to_onehot": NUM_CLASSES},
    )
    # NOTE(review): save_image is constructed but never added to any config
    # below; kept so the example still shows how it would be declared.
    save_image = Component(
        name=SaveImaged,
        args={
            "keys": "pred",
            "meta_keys": "pred_meta_dict",
            "output_dir": "{MMAR_EVAL_OUTPUT_PATH}",
            "resample": False,
            "squeeze_end_dims": True,
        },
    )

    train_avg_auc = Component(
        name=ROCAUC,
        vars={"log_label": "train_avg_auc"},
        args={
            "output_transform": "#monai.handlers.from_engine(['binary_preds', 'label'])",
        },
    )
    train_acc = Component(
        name=Accuracy,
        vars={"log_label": "train_acc"},
        args={
            "output_transform": "#monai.handlers.from_engine(['pred', 'label'])",
        },
    )

    # ------------------------------------------------------------------
    # trainer
    # ------------------------------------------------------------------
    trainer = Component(
        name=SupervisedTrainer,
        args={
            "max_epochs": "{epochs}",
            "device": "cuda",
            "train_data_loader": "@dataloader",
            "network": "@model",
            "loss_function": "@loss",
            "optimizer": "@optimizer",
            "inferer": "@inferer",
            "postprocessing": "@post_transforms",
            "key_train_metric": "@key_metric",
            "additional_metrics": "@additional_metrics",
            "train_handlers": "@handlers",
            "amp": "{amp}",
        },
    )

    # ------------------------------------------------------------------
    # validation dataset, dataloader, inferer and handlers
    # ------------------------------------------------------------------
    val_dataset = Component(
        name=PersistentDataset,
        vars={
            "data_list_file_path": "{DATASET_JSON}",
            "data_file_base_dir": "{DATA_ROOT}",
            "data_list_key": "{VAL_DATALIST_KEY}",
        },
        args={"transform": "@pre_transforms", "cache_dir": "val_cache"},
    )
    val_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": "@dataset",
            "batch_size": 20,
            "shuffle": False,
            "num_workers": 4,
        },
    )
    val_inferer = Component(name=SimpleInferer)

    val_stats_handler = Component(
        name=StatsHandler,
        vars={"rank": 0},
        args={"output_transform": "lambda x: None"},
    )
    val_tb_stats_handler = Component(
        name=TensorBoardStatsHandler,
        vars={"rank": 0},
        args={
            "log_dir": "{MMAR_CKPT_DIR}",
            "output_transform": "lambda x: None",
        },
    )
    val_checkpoint_saver = Component(
        name=CheckpointSaver,
        vars={"rank": 0},
        args={
            "save_dir": "{MMAR_CKPT_DIR}",
            "save_dict": {"model": "@model", "train_conf": "@conf"},
            "save_key_metric": True,
        },
    )

    # ------------------------------------------------------------------
    # validation post-transforms and metrics
    # ------------------------------------------------------------------
    val_copy_items = Component(
        name=CopyItemsd,
        args={"times": 1, "keys": "pred", "names": ["accur_preds"]},
    )
    split_channel = Component(
        name=SplitChanneld,
        args={
            "keys": ["pred", "label"],
            "output_postfixes": list(CLASS_NAMES),
        },
    )
    val_as_discrete = Component(
        name=AsDiscreted,
        args={
            "keys": ["accur_preds"],
            "argmax": [True],
            "to_onehot": NUM_CLASSES,
        },
    )
    val_post_transforms = [
        activation,
        val_copy_items,
        split_channel,
        val_as_discrete,
    ]

    val_key_metric = Component(
        name=ROCAUC,
        vars={"log_label": "Average_AUC"},
        args={
            "average": "macro",
            "output_transform": "#monai.handlers.from_engine(['pred', 'label'])",
        },
    )
    # Overall accuracy first, then one AUC per finding in CLASS_NAMES order.
    val_additional_metrics = [
        Component(
            name=Accuracy,
            vars={"log_label": "val_acc"},
            args={
                "output_transform": "#monai.handlers.from_engine(['accur_preds', 'label'])",
            },
        ),
        *_per_class_auc_metrics(),
    ]

    # ------------------------------------------------------------------
    # evaluator (shared by the train and the validate config)
    # ------------------------------------------------------------------
    evaluator = Component(
        name=SupervisedEvaluator,
        args={
            "device": "cuda",
            "val_data_loader": "@dataloader",
            "network": "@model",
            "inferer": "@inferer",
            "postprocessing": "@post_transforms",
            "key_val_metric": "@key_metric",
            "additional_metrics": "@additional_metrics",
            "val_handlers": "@handlers",
            "amp": "{amp}",
        },
    )

    # ------------------------------------------------------------------
    # evaluation dataset, dataloader and handlers
    # ------------------------------------------------------------------
    eval_dataset = Component(
        name=Dataset,
        vars={
            "data_list_file_path": "{DATASET_JSON}",
            "data_file_base_dir": "{DATA_ROOT}",
            "data_list_key": "{VAL_DATALIST_KEY}",
        },
        args={"transform": "@pre_transforms"},
    )
    eval_dataloader = Component(
        name=DataLoader,
        args={
            "dataset": "@dataset",
            "batch_size": 1,
            "shuffle": False,
            "num_workers": 4,
        },
    )
    eval_classification_saver = Component(
        name=ClassificationSaver,
        args={
            "output_dir": "{MMAR_EVAL_OUTPUT_PATH}",
            "batch_transform": "#monai.handlers.from_engine(['image_meta_dict'])",
            "output_transform": "#monai.handlers.from_engine(['pred'])",
        },
    )
    eval_metrics_saver = Component(
        name=MetricsSaver,
        args={
            "save_dir": "{MMAR_EVAL_OUTPUT_PATH}",
            "metrics": "*",
            "metric_details": None,
            "batch_transform": None,
            "summary_ops": None,
            "save_rank": 0,
        },
    )

    # ------------------------------------------------------------------
    # train config
    # ------------------------------------------------------------------
    train_conf = TrainConfig()
    train_conf.add_vars(
        epochs=40,
        num_interval_per_valid=1,
        learning_rate=1e-4,
        multi_gpu=False,
        amp=True,
        cudnn_benchmark=False,
        dont_load_ckpt_model=True,
    )
    train_conf.add_loss(loss)
    train_conf.add_optimizer(optimizer)
    train_conf.add_model(train_model)
    train_conf.add_lr_scheduler(lr_scheduler)
    train_conf.add_train_pre_transform(train_pre_transforms)
    train_conf.add_train_dataset(train_dataset)
    train_conf.add_train_dataloader(train_dataloader)
    train_conf.add_train_inferer(train_inferer)
    train_conf.add_train_handler([
        checkpoint_loader,
        lr_scheduler_handler,
        validation_handler,
        train_checkpoint_saver,
        train_stats_handler,
        train_tb_stats_handler,
    ])
    train_conf.add_train_post_transform([activation, copy_items, as_discrete])
    train_conf.add_train_key_metric(train_avg_auc)
    train_conf.add_train_additional_metric(train_acc)
    train_conf.add_trainer(trainer)
    # Each config receives its own list object (same components inside), as
    # in the original example where every call got a fresh list literal.
    train_conf.add_val_pre_transform(list(deterministic_pre_transforms))
    train_conf.add_val_dataset(val_dataset)
    train_conf.add_val_dataloader(val_dataloader)
    train_conf.add_val_inferer(val_inferer)
    train_conf.add_val_handler([
        val_stats_handler,
        val_tb_stats_handler,
        val_checkpoint_saver,
    ])
    train_conf.add_val_post_transform(list(val_post_transforms))
    train_conf.add_val_key_metric(val_key_metric)
    train_conf.add_val_additional_metric(val_additional_metrics)
    train_conf.add_evaluator(evaluator)

    # ------------------------------------------------------------------
    # validate config
    # ------------------------------------------------------------------
    eval_conf = ValidateConfig()
    eval_conf.add_vars(
        multi_gpu=False,
        amp=False,
        dont_load_ts_model=False,
        dont_load_ckpt_model=True,
    )
    eval_conf.add_model([
        Component(vars={"ts_path": "{MMAR_TORCHSCRIPT}", "disabled": "{dont_load_ts_model}"}),
        Component(vars={"ckpt_path": "{MMAR_CKPT}", "disabled": "{dont_load_ckpt_model}"}),
    ])
    eval_conf.add_pre_transform(list(deterministic_pre_transforms))
    eval_conf.add_dataset(eval_dataset)
    eval_conf.add_dataloader(eval_dataloader)
    eval_conf.add_inferer(val_inferer)
    eval_conf.add_handler([
        val_stats_handler,
        checkpoint_loader,
        eval_classification_saver,
        eval_metrics_saver,
    ])
    eval_conf.add_post_transform(list(val_post_transforms))
    eval_conf.add_key_metric(val_key_metric)
    eval_conf.add_additional_metric(val_additional_metrics)
    eval_conf.add_evaluator(evaluator)

    # ------------------------------------------------------------------
    # assemble the MMAR
    # ------------------------------------------------------------------
    mmar = MMAR(root_path=MMAR_ROOT_PATH, datalist_name="plco.json")
    mmar.set_train_config(train_conf)
    mmar.set_validate_config(eval_conf)
    mmar.set_resources({"log.config": LOG_CONFIG})
    mmar.set_environment(env={
        "CLARA_TRAIN_VERSION": "4.1.0",
        "DATA_ROOT": "/workspace/data/CXR/PLCO/PLCO_256_original",
        "DATASET_JSON": "config/plco.json",
        "PROCESSING_TASK": "classification",
        "TRAIN_DATALIST_KEY": "training",
        "VAL_DATALIST_KEY": "validation",
        "INFER_DATALIST_KEY": "test",
        "MMAR_EVAL_OUTPUT_PATH": "eval",
        "MMAR_CKPT_DIR": "models",
        "MMAR_CKPT": "models/model.pt",
        "MMAR_TORCHSCRIPT": "models/model.ts",
        "INPUT_CHANNELS": 1,
        "OUTPUT_CHANNELS": NUM_CLASSES,
    })

    # The datalist is read from an existing plco.json instead of being
    # written inline because it is over 17MB; the file must already exist.
    with open(plco_datalist_path, encoding="utf-8") as datalist_file:
        plco_datalist = json.load(datalist_file)
    mmar.set_datalist(datalist=plco_datalist)

    mmar.set_commands({
        "set_env.sh": SET_ENV_SH,
        "train.sh": TRAIN_SH,
        "validate_ckpt.sh": VALIDATE_CKPT_SH,
    })
    # NOTE: to keep this example simple, only these three commands are added;
    # docs, inference config, etc. are left out.

    # build MMAR
    mmar.build()

    # Smoke test: load the MMAR that was just built and print both configs.
    test_mmar = MMAR()
    test_mmar.load(root_path=MMAR_ROOT_PATH)
    print(test_mmar.get_train_config().get_config())
    print(test_mmar.get_validate_config().get_config())


if __name__ == "__main__":
    main()

© Copyright 2021, NVIDIA. Last updated on Feb 2, 2023.