Create a chest X-ray classification MMAR with the MMAR API
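This example assembles a complete MMAR for 15-finding chest X-ray classification in Python, using the MMAR API's Component, TrainConfig, ValidateConfig, and MMAR classes together with MONAI networks, transforms, engines, and handlers. In the component arguments, string values such as "@model", "{learning_rate}", and "#monai.handlers.from_engine(...)" are references resolved in the generated MMAR config rather than live Python objects: as used below, "@name" points at another component, "{NAME}" at a config or environment variable, and "#..." at an inline expression. Treat that as an informal reading of the conventions in this script, not a formal specification of the config syntax.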
import json
from ignite.metrics import Accuracy
from medl.tools.mmar_creator.component import Component
from medl.tools.mmar_creator.mmar import MMAR
from medl.tools.mmar_creator.train_config import TrainConfig
from medl.tools.mmar_creator.validate_config import ValidateConfig
from monai.data import CacheDataset, DataLoader, Dataset, PersistentDataset
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.handlers import (CheckpointLoader, CheckpointSaver,
ClassificationSaver, LrScheduleHandler,
MetricsSaver, StatsHandler,
TensorBoardStatsHandler, ValidationHandler)
from monai.handlers.roc_auc import ROCAUC
from monai.inferers import SimpleInferer
from monai.networks.nets import DenseNet
from monai.transforms import (Activationsd, AddChanneld, AsDiscreted,
CastToTyped, CopyItemsd, LoadImaged,
NormalizeIntensityd, RandRotated,
RandSpatialCropd, Resized, SaveImaged,
SplitChanneld, ToNumpyd, ToTensord)
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
def main():
plco_datalist_path = "local_path_to/plco.json"
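# The datalist is expected to follow the usual MONAI/Clara datalist layout referenced by
# TRAIN_DATALIST_KEY / VAL_DATALIST_KEY below. A hypothetical entry (illustrative only; the
# real plco.json must already exist at plco_datalist_path) might look like:
# {
#     "training": [
#         {"image": "images/patient_0001.png",
#          "label": [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}   # 15 findings, multi-hot
#     ],
#     "validation": [...]
# }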
# define train components
loss = Component(name=BCEWithLogitsLoss)
optimizer = Component(name=Adam, args={"params": "#@model.parameters()", "lr": "{learning_rate}", "weight_decay": 1e-5, "amsgrad": False})
train_model = Component(name=DenseNet, args={"init_features": 64, "growth_rate": 32, "block_config": [6, 12, 24, 16], "spatial_dims": 2, "in_channels": "{INPUT_CHANNELS}", "out_channels": "{OUTPUT_CHANNELS}"})
lr_scheduler = Component(name=StepLR, args={"optimizer": "@optimizer", "step_size": 40, "gamma": 0.1})
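# Optional sanity check (not part of the MMAR): the same DenseNet settings can be instantiated
# directly to confirm the expected shapes; the literal 1 / 15 below stand in for the
# INPUT_CHANNELS / OUTPUT_CHANNELS environment variables used above.
# import torch
# net = DenseNet(spatial_dims=2, in_channels=1, out_channels=15,
#                init_features=64, growth_rate=32, block_config=[6, 12, 24, 16])
# logits = net(torch.randn(2, 1, 256, 256))                # raw logits, shape (2, 15)
# print(BCEWithLogitsLoss()(logits, torch.zeros(2, 15)))   # expects multi-hot float targets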
# define pre_transforms
load_image = Component(name=LoadImaged, args={"keys": ["image"]})
add_channel = Component(name=AddChanneld, args={"keys": ["image"]})
rand_spatial_crop = Component(name=RandSpatialCropd, args={"keys": "image", "roi_size": [230, 230], "random_center": True, "random_size": True})
resize = Component(name=Resized, args={"keys": ["image"], "spatial_size": [256, 256]})
rand_rotate = Component(name=RandRotated, args={"keys": ["image"], "range_x": 7})
normalize_intensity = Component(name=NormalizeIntensityd, args={"keys": ["image"], "subtrahend": 2876.37, "divisor": 883, "dtype": "float32"})
to_numpy = Component(name=ToNumpyd, args={"keys": ["label"]})
cast_to_type = Component(name=CastToTyped, args={"keys": ["label"], "dtype": "float32"})
to_tensor = Component(name=ToTensord, args={"keys": ["image", "label"]})
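# Optional sanity check (not part of the MMAR): the same MONAI transforms can be chained with
# Compose and run on a single sample before generating the config; "example_cxr.png" is a
# hypothetical grayscale image used only to illustrate the expected dictionary input.
# from monai.transforms import Compose
# check = Compose([LoadImaged(keys=["image"]), AddChanneld(keys=["image"]),
#                  Resized(keys=["image"], spatial_size=[256, 256]),
#                  NormalizeIntensityd(keys=["image"], subtrahend=2876.37, divisor=883),
#                  ToTensord(keys=["image"])])
# print(check({"image": "example_cxr.png"})["image"].shape)  # expected: (1, 256, 256)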
# define train dataset and dataloader
train_dataset = Component(
name=CacheDataset,
vars={"data_list_file_path": "{DATASET_JSON}", "data_file_base_dir": "{DATA_ROOT}", "data_list_key": "{TRAIN_DATALIST_KEY}"},
args={"transform": "@pre_transforms", "cache_num": 128, "cache_rate": 0.2, "num_workers": 5},
)
train_dataloader = Component(name=DataLoader, args={"dataset": "@dataset", "batch_size": 20, "shuffle": True, "num_workers": 5})
# define train inferer
train_inferer = Component(name=SimpleInferer)
# define train handlers
checkpoint_loader = Component(name=CheckpointLoader, vars={"disabled": "{dont_load_ckpt_model}"}, args={"load_path": "{MMAR_CKPT}", "load_dict": {"model": "@model"}})
lr_scheduler_handler = Component(name=LrScheduleHandler, args={"lr_scheduler": "@lr_scheduler", "print_lr": True})
validation_handler = Component(name=ValidationHandler, args={"validator": "@evaluator", "epoch_level": True, "interval": "{num_interval_per_valid}"})
train_checkpoint_saver = Component(name=CheckpointSaver, vars={"rank": 0}, args={"save_dir": "{MMAR_CKPT_DIR}", "save_dict": {"model": "@model", "optimizer": "@optimizer", "lr_scheduler": "@lr_scheduler", "train_conf": "@conf"}, "save_final": True, "save_interval": 400})
train_stats_handler = Component(name=StatsHandler, vars={"rank": 0}, args={"tag_name": "train_loss", "output_transform": "#monai.handlers.from_engine(['loss'], first=True)"})
train_tb_stats_handler = Component(name=TensorBoardStatsHandler, vars={"rank": 0}, args={"log_dir": "{MMAR_CKPT_DIR}", "tag_name": "train_loss", "output_transform": "#monai.handlers.from_engine(['loss'], first=True)"})
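# The output_transform strings above rely on monai.handlers.from_engine: from_engine(['loss'],
# first=True) builds a callable that pulls the scalar loss out of the engine output for logging,
# while from_engine(['pred', 'label']) pairs predictions with labels for the metrics below.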
# define post transforms
activation = Component(name=Activationsd, args={"keys": "pred", "sigmoid": True})
copy_items = Component(name=CopyItemsd, args={"times": 1, "keys": "pred", "names": ["binary_preds"]})
as_discrete = Component(name=AsDiscreted, args={"keys": ["pred"], "argmax": [True], "to_onehot": 15})
save_image = Component(name=SaveImaged, args={"keys": "pred", "meta_keys": "pred_meta_dict", "output_dir": "{MMAR_EVAL_OUTPUT_PATH}", "resample": False, "squeeze_end_dims": True})
# define train metrics
train_avg_auc = Component(name=ROCAUC, vars={"log_label": "train_avg_auc"}, args={"output_transform": "#monai.handlers.from_engine(['binary_preds', 'label'])"})
train_acc = Component(name=Accuracy, vars={"log_label": "train_acc"}, args={"output_transform": "#monai.handlers.from_engine(['pred', 'label'])"})
# define trainer
trainer = Component(
name=SupervisedTrainer,
args={
"max_epochs": "{epochs}",
"device": "cuda",
"train_data_loader": "@dataloader",
"network": "@model",
"loss_function": "@loss",
"optimizer": "@optimizer",
"inferer": "@inferer",
"postprocessing": "@post_transforms",
"key_train_metric": "@key_metric",
"additional_metrics": "@additional_metrics",
"train_handlers": "@handlers",
"amp": "{amp}",
}
)
# define validation dataset and dataloader
val_dataset = Component(
name=PersistentDataset,
vars={"data_list_file_path": "{DATASET_JSON}", "data_file_base_dir": "{DATA_ROOT}", "data_list_key": "{VAL_DATALIST_KEY}"},
args={"transform": "@pre_transforms", "cache_dir": "val_cache"},
)
val_dataloader = Component(name=DataLoader, args={"dataset": "@dataset", "batch_size": 20, "shuffle": False, "num_workers": 4})
# define validation inferer
val_inferer = Component(name=SimpleInferer)
# define validation handlers
val_stats_handler = Component(name=StatsHandler, vars={"rank": 0}, args={"output_transform": "lambda x: None"})
val_tb_stats_handler = Component(name=TensorBoardStatsHandler, vars={"rank": 0}, args={"log_dir": "{MMAR_CKPT_DIR}", "output_transform": "lambda x: None"})
val_checkpoint_saver = Component(name=CheckpointSaver, vars={"rank": 0}, args={"save_dir": "{MMAR_CKPT_DIR}", "save_dict": {"model": "@model", "train_conf": "@conf"}, "save_key_metric": True})
# define validation post transforms
val_copy_items = Component(name=CopyItemsd, args={"times": 1, "keys": "pred", "names": ["accur_preds"]})
split_channel = Component(name=SplitChanneld, args={"keys": ["pred", "label"], "output_postfixes": ["Nodule",
"Mass",
"Distortion_pulmonary_architecture",
"Pleural_based_mass",
"Granuloma",
"Fluid_in_pleural_space",
"Right_hilar_abnormality",
"Left_hilar_abnormality",
"Major_atelectasis",
"Infiltrate",
"Scarring",
"Pleural_fibrosis",
"Bone_soft_tissue_lesion",
"Cardiac_abnormality",
"COPD"]})
val_as_discrete = Component(name=AsDiscreted, args={"keys": ["accur_preds"], "argmax": [True], "to_onehot": 15})
# define validation metrics
val_key_metric = Component(name=ROCAUC, vars={"log_label": "Average_AUC"}, args={"average": "macro", "output_transform": "#monai.handlers.from_engine(['pred', 'label'])"})
val_additional_metrics = [Component(name=Accuracy, vars={"log_label": "val_acc"}, args={"output_transform": "#monai.handlers.from_engine(['accur_preds', 'label'])"}),
Component(name=ROCAUC, vars={"log_label": "Nodule"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Nodule', 'label_Nodule'])"}),
Component(name=ROCAUC, vars={"log_label": "Mass"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Mass', 'label_Mass'])"}),
Component(name=ROCAUC, vars={"log_label": "Distortion_pulmonary_architecture"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Distortion_pulmonary_architecture', 'label_Distortion_pulmonary_architecture'])"}),
Component(name=ROCAUC, vars={"log_label": "Pleural_based_mass"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Pleural_based_mass', 'label_Pleural_based_mass'])"}),
Component(name=ROCAUC, vars={"log_label": "Granuloma"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Granuloma', 'label_Granuloma'])"}),
Component(name=ROCAUC, vars={"log_label": "Fluid_in_pleural_space"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Fluid_in_pleural_space', 'label_Fluid_in_pleural_space'])"}),
Component(name=ROCAUC, vars={"log_label": "Right_hilar_abnormality"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Right_hilar_abnormality', 'label_Right_hilar_abnormality'])"}),
Component(name=ROCAUC, vars={"log_label": "Left_hilar_abnormality"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Left_hilar_abnormality', 'label_Left_hilar_abnormality'])"}),
Component(name=ROCAUC, vars={"log_label": "Major_atelectasis"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Major_atelectasis', 'label_Major_atelectasis'])"}),
Component(name=ROCAUC, vars={"log_label": "Infiltrate"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Infiltrate', 'label_Infiltrate'])"}),
Component(name=ROCAUC, vars={"log_label": "Scarring"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Scarring', 'label_Scarring'])"}),
Component(name=ROCAUC, vars={"log_label": "Pleural_fibrosis"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Pleural_fibrosis', 'label_Pleural_fibrosis'])"}),
Component(name=ROCAUC, vars={"log_label": "Bone_soft_tissue_lesion"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Bone_soft_tissue_lesion', 'label_Bone_soft_tissue_lesion'])"}),
Component(name=ROCAUC, vars={"log_label": "Cardiac_abnormality"}, args={"output_transform": "#monai.handlers.from_engine(['pred_Cardiac_abnormality', 'label_Cardiac_abnormality'])"}),
Component(name=ROCAUC, vars={"log_label": "COPD"}, args={"output_transform": "#monai.handlers.from_engine(['pred_COPD', 'label_COPD'])"})]
# define evaluator
evaluator = Component(
name=SupervisedEvaluator,
args={
"device": "cuda",
"val_data_loader": "@dataloader",
"network": "@model",
"inferer": "@inferer",
"postprocessing": "@post_transforms",
"key_val_metric": "@key_metric",
"additional_metrics": "@additional_metrics",
"val_handlers": "@handlers",
"amp": "{amp}",
}
)
# define evaluation dataset
eval_dataset = Component(
name=Dataset,
vars={"data_list_file_path": "{DATASET_JSON}", "data_file_base_dir": "{DATA_ROOT}", "data_list_key": "{VAL_DATALIST_KEY}"},
args={"transform": "@pre_transforms"},
)
eval_dataloader = Component(name=DataLoader, args={"dataset": "@dataset", "batch_size": 1, "shuffle": False, "num_workers": 4})
# define evaluation handlers
eval_classification_saver = Component(name=ClassificationSaver, args={"output_dir": "{MMAR_EVAL_OUTPUT_PATH}", "batch_transform": "#monai.handlers.from_engine(['image_meta_dict'])", "output_transform": "#monai.handlers.from_engine(['pred'])"})
eval_metrics_saver = Component(name=MetricsSaver, args={"save_dir": "{MMAR_EVAL_OUTPUT_PATH}", "metrics": "*", "metric_details": None, "batch_transform": None, "summary_ops": None, "save_rank": 0})
# create train config based on above components
train_conf = TrainConfig()
# set variables in config
train_conf.add_vars(epochs=40, num_interval_per_valid=1, learning_rate=1e-4, multi_gpu=False, amp=True, cudnn_benchmark=False, dont_load_ckpt_model=True)
train_conf.add_loss(loss)
train_conf.add_optimizer(optimizer)
train_conf.add_model(train_model)
train_conf.add_lr_scheduler(lr_scheduler)
train_conf.add_train_pre_transform([
load_image,
add_channel,
rand_spatial_crop,
resize,
rand_rotate,
normalize_intensity,
to_numpy,
cast_to_type,
to_tensor,
])
train_conf.add_train_dataset(train_dataset)
train_conf.add_train_dataloader(train_dataloader)
train_conf.add_train_inferer(train_inferer)
train_conf.add_train_handler([
checkpoint_loader,
lr_scheduler_handler,
validation_handler,
train_checkpoint_saver,
train_stats_handler,
train_tb_stats_handler,
])
train_conf.add_train_post_transform([activation, copy_items, as_discrete])
train_conf.add_train_key_metric(train_avg_auc)
train_conf.add_train_additional_metric(train_acc)
train_conf.add_trainer(trainer)
train_conf.add_val_pre_transform([
load_image,
add_channel,
resize,
normalize_intensity,
to_numpy,
cast_to_type,
to_tensor,
])
train_conf.add_val_dataset(val_dataset)
train_conf.add_val_dataloader(val_dataloader)
train_conf.add_val_inferer(val_inferer)
train_conf.add_val_handler([
val_stats_handler,
val_tb_stats_handler,
val_checkpoint_saver,
])
train_conf.add_val_post_transform([activation, val_copy_items, split_channel, val_as_discrete])
train_conf.add_val_key_metric(val_key_metric)
train_conf.add_val_additional_metric(val_additional_metrics)
train_conf.add_evaluator(evaluator)
# create validation config based on above components
eval_conf = ValidateConfig()
# set variables in the config
eval_conf.add_vars(multi_gpu=False, amp=False, dont_load_ts_model=False, dont_load_ckpt_model=True)
eval_conf.add_model([
Component(vars={"ts_path": "{MMAR_TORCHSCRIPT}", "disabled": "{dont_load_ts_model}"}),
Component(vars={"ckpt_path": "{MMAR_CKPT}", "disabled": "{dont_load_ckpt_model}"}),
])
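# Validation can restore the network either from the TorchScript export (MMAR_TORCHSCRIPT) or from
# the checkpoint (MMAR_CKPT); the dont_load_ts_model / dont_load_ckpt_model variables set above
# (and overridden in validate_ckpt.sh below) decide which of the two model entries is disabled.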
eval_conf.add_pre_transform([
load_image,
add_channel,
resize,
normalize_intensity,
to_numpy,
cast_to_type,
to_tensor,
])
eval_conf.add_dataset(eval_dataset)
eval_conf.add_dataloader(eval_dataloader)
eval_conf.add_inferer(val_inferer)
eval_conf.add_handler([
val_stats_handler,
checkpoint_loader,
eval_classification_saver,
eval_metrics_saver,
])
eval_conf.add_post_transform([activation, val_copy_items, split_channel, val_as_discrete])
eval_conf.add_key_metric(val_key_metric)
eval_conf.add_additional_metric(val_additional_metrics)
eval_conf.add_evaluator(evaluator)
# create MMAR
mmar = MMAR(root_path="pt_chest_xray_classification", datalist_name="plco.json")
mmar.set_train_config(train_conf)
mmar.set_validate_config(eval_conf)
mmar.set_resources({"log.config": "[loggers]\nkeys=root,modelLogger\n[handlers]\nkeys=consoleHandler\n[formatters]\nkeys=fullFormatter\n[logger_root]\nlevel=INFO\nhandlers=consoleHandler\n[logger_modelLogger]\nlevel=DEBUG\nhandlers=consoleHandler\nqualname=modelLogger\npropagate=0\n[handler_consoleHandler]\nclass=StreamHandler\nlevel=DEBUG\nformatter=fullFormatter\nargs=(sys.stdout,)\n[formatter_fullFormatter]\nformat=%(asctime)s-%(name)s-%(levelname)s-%(message)s"})
mmar.set_environment(env={
"CLARA_TRAIN_VERSION": "4.1.0",
"DATA_ROOT": "/workspace/data/CXR/PLCO/PLCO_256_original",
"DATASET_JSON": "config/plco.json",
"PROCESSING_TASK": "classification",
"TRAIN_DATALIST_KEY": "training",
"VAL_DATALIST_KEY": "validation",
"INFER_DATALIST_KEY": "test",
"MMAR_EVAL_OUTPUT_PATH": "eval",
"MMAR_CKPT_DIR": "models",
"MMAR_CKPT": "models/model.pt",
"MMAR_TORCHSCRIPT": "models/model.ts",
"INPUT_CHANNELS": 1,
"OUTPUT_CHANNELS": 15
})
# the plco.json datalist is over 17 MB, so load it from an existing local file
# (plco_datalist_path above) instead of defining its contents inline
with open(plco_datalist_path) as f:
plco_datalist = json.load(f)
mmar.set_datalist(datalist=plco_datalist)
mmar.set_commands({
"set_env.sh": "#!/usr/bin/env bash\nexport PYTHONPATH=\"$PYTHONPATH:/opt/nvidia\"\nDIR=\"$( cd\"$( dirname\"${BASH_SOURCE[0]}\")\">/dev/null 2>&1 && pwd )\"\nexport MMAR_ROOT=${DIR}/..",
"train.sh": "#!/usr/bin/env bash\nmy_dir=\"$(dirname\"$0\")\"\n. $my_dir/set_env.sh\necho\"MMAR_ROOT set to $MMAR_ROOT\"\nadditional_options=\"$*\"\nCONFIG_FILE=config/config_train.json\nENVIRONMENT_FILE=config/environment.json\npython3 -u -m medl.apps.train -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE --write_train_stats --set print_conf=True multi_gpu=False cudnn_benchmark=False dont_load_ckpt_model=True ${additional_options}",
"validate_ckpt.sh": "#!/usr/bin/env bash\nmy_dir=\"$(dirname\"$0\")\"\n. $my_dir/set_env.sh\necho\"MMAR_ROOT set to $MMAR_ROOT\"\nadditional_options=\"$*\"\nCONFIG_FILE=config/config_validation.json\nENVIRONMENT_FILE=config/environment.json\npython3 -u -m medl.apps.evaluate -m $MMAR_ROOT -c $CONFIG_FILE -e $ENVIRONMENT_FILE --set print_conf=True multi_gpu=False dont_load_ts_model=True dont_load_ckpt_model=False ${additional_options}",
})
# NOTE: to keep this example simple, we don't add docs, an inference config, etc.
# build MMAR
mmar.build()
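# After build(), the MMAR root (pt_chest_xray_classification) is expected to follow the usual
# MMAR layout, presumably with the generated files under config/ (config_train.json,
# config_validation.json, environment.json, plco.json), the scripts under commands/, and
# resources/log.config; the exact layout is determined by the MMAR API.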
# test load MMAR
test_mmar = MMAR()
test_mmar.load(root_path="pt_chest_xray_classification")
print(test_mmar.get_train_config().get_config())
print(test_mmar.get_validate_config().get_config())
if __name__ == "__main__":
main()