Bring your own components (BYOC)
Clara Train lets researchers solve new or different problems and innovate by writing their own components in a modular way. To do so, write the components in Python files, then point to these files in config_train.json by providing the path of each new component.
You can look at the component implementations in MONAI and use any of them as a template for your own components.
Basic components such as networks and loss functions do not need to be tied to ignite or MONAI workflows, as long as they conform to the expected PyTorch APIs. To use one:
1. Put the BYOC code in the MMAR's custom folder; for example, place my_model.py in a directory named "custom" inside the MMAR.
2. Configure $PYTHONPATH to include that folder, for example by adding this line to set_env.sh:
   export PYTHONPATH="$PYTHONPATH:/opt/nvidia:$MMAR_ROOT/custom"
3. Use the custom component in the MMAR config by specifying "path" in place of "name":
   "model": {
       "path": "my_model.Model",
       "args": {
           ...
       }
   }
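Clara Train performs the "path" lookup itself; the sketch below only illustrates why the folder holding my_model.py must be on PYTHONPATH. It assumes that a loader splits "my_model.Model" into a module part and a class part, imports the module from sys.path, and calls the class with "args" as keyword arguments. `build_component` is a hypothetical helper, not Clara Train's loader.

```python
import importlib
from typing import Any, Dict


def build_component(entry: Dict[str, Any]) -> Any:
    """Instantiate a component from a {"path": "module.ClassName", "args": {...}} entry.

    Hypothetical helper for illustration only; it is not Clara Train's loader.
    """
    module_name, _, class_name = entry["path"].rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {entry['path']!r}")
    # import_module searches sys.path, which includes every PYTHONPATH entry,
    # so with the set_env.sh setting "my_model" resolves to custom/my_model.py.
    module = importlib.import_module(module_name)
    component_cls = getattr(module, class_name)
    # The "args" mapping becomes the constructor's keyword arguments.
    return component_cls(**entry.get("args", {}))


# Usage with a standard-library class standing in for my_model.Model:
fraction = build_component({"path": "fractions.Fraction", "args": {"numerator": 3, "denominator": 4}})
print(fraction)  # 3/4
```

A dotted string keeps the JSON config free of Python imports: the config names the component, and the environment (PYTHONPATH) decides where it is found.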
For other custom logic, we define three levels of API that together offer a wide range of customizability:
Handlers: run custom logic when Events fire, such as started, completed, epoch_started, and epoch_completed. MONAI already provides a rich set of handlers: https://docs.monai.io/en/latest/handlers.html
Iteration: write arbitrary computation logic for every iteration.
Extend SupervisedTrainer: write a custom trainer to implement very complicated logic.
See an example notebook for BYOC.
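To make the relationship between the three levels concrete, the following toy sketch mimics the event flow of a training run in plain Python. `ToyEngine`, `ToyEvents`, `toy_iteration`, and `LogHandler` are hypothetical stand-ins, not the ignite or MONAI API. The event order is modeled on the events used in the examples below: a handler on ITERATION_STARTED runs before the forward pass, and one on FORWARD_COMPLETED runs after it but before the loss is computed.

```python
from enum import Enum, auto
from typing import Callable, Dict, List


class ToyEvents(Enum):
    """Hypothetical event names, modeled on the ignite/MONAI events used in this section."""
    STARTED = auto()
    ITERATION_STARTED = auto()
    FORWARD_COMPLETED = auto()
    ITERATION_COMPLETED = auto()
    COMPLETED = auto()


class ToyEngine:
    """Minimal stand-in for a training engine; NOT the ignite/MONAI API."""

    def __init__(self, iteration_fn: Callable[["ToyEngine", dict], dict]) -> None:
        self.iteration_fn = iteration_fn  # level 2: per-iteration computation
        self.handlers: Dict[ToyEvents, List[Callable[["ToyEngine"], None]]] = {
            event: [] for event in ToyEvents
        }
        self.iteration = 0
        self.output: dict = {}

    def add_event_handler(self, event: ToyEvents, fn: Callable[["ToyEngine"], None]) -> None:
        self.handlers[event].append(fn)

    def fire_event(self, event: ToyEvents) -> None:
        for fn in self.handlers[event]:
            fn(self)

    def run(self, data: List[dict]) -> None:
        # Level 3 corresponds to subclassing the engine and overriding methods like this one.
        self.fire_event(ToyEvents.STARTED)
        for batch in data:
            self.iteration += 1
            self.fire_event(ToyEvents.ITERATION_STARTED)
            self.output = self.iteration_fn(self, batch)
            self.fire_event(ToyEvents.ITERATION_COMPLETED)
        self.fire_event(ToyEvents.COMPLETED)


def toy_iteration(engine: ToyEngine, batch: dict) -> dict:
    """Stand-in for one iteration: a fake forward pass, then the loss."""
    engine.output = {"pred": batch["x"] * 2.0}  # fake forward pass
    engine.fire_event(ToyEvents.FORWARD_COMPLETED)  # handlers may modify "pred" here
    pred = engine.output["pred"]
    return {"pred": pred, "loss": abs(pred - batch["y"])}


class LogHandler:
    """Level 1: a handler that records the name of every event it sees."""

    def __init__(self) -> None:
        self.log: List[str] = []

    def attach(self, engine: ToyEngine) -> None:
        for event in ToyEvents:
            # The default argument freezes the current event name for each lambda.
            engine.add_event_handler(event, lambda _engine, name=event.name: self.log.append(name))


engine = ToyEngine(toy_iteration)
logger = LogHandler()
logger.attach(engine)
engine.run([{"x": 1.0, "y": 2.0}])
print(logger.log)
# ['STARTED', 'ITERATION_STARTED', 'FORWARD_COMPLETED', 'ITERATION_COMPLETED', 'COMPLETED']
```

The handler never touches the iteration function, and the iteration function never knows which handlers exist; that separation is what lets each level be customized independently.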
Bring extra input data for the network
import torch
from ignite.engine import Events

class ExtraInputNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Set by ExtraInputHandler before each forward pass.
        self.extra_tensor = None

    def forward(self, image):
        return image + self.extra_tensor

class ExtraInputHandler:
    def attach(self, engine):
        engine.add_event_handler(Events.ITERATION_STARTED, self)

    def __call__(self, engine):
        # Copy the batch's "extra" item onto the network before the forward pass.
        engine.network.extra_tensor = engine.state.batch["extra"].to(engine.state.device)
Keep the network's initial weights for use in the loss
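import copy
import torch
from ignite.engine import Events

class LossWithWeights(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.init_weights = None  # set by SaveInitWeightsHandler at Events.STARTED

    def forward(self, pred, label):
        loss = torch.mean((pred - label) ** 2)
        # Placeholder: combine `loss` with tensors from self.init_weights (a state_dict) here.
        return loss

class SaveInitWeightsHandler:
    def attach(self, engine):
        engine.add_event_handler(Events.STARTED, self)

    def __call__(self, engine):
        # deepcopy: state_dict() tensors share memory with the live parameters.
        engine.loss_function.init_weights = copy.deepcopy(engine.network.state_dict())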
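The tensors returned by state_dict() share storage with the network's parameters, so storing them without copy.deepcopy would make the "initial" weights silently follow every optimizer update. The pure-Python sketch below uses a dict of lists as a stand-in for a state dict (an assumption made so it runs without PyTorch), and `fake_optimizer_step` is a hypothetical in-place update, to show the difference.

```python
import copy
from typing import Dict, List


def fake_optimizer_step(params: Dict[str, List[float]], lr: float = 0.5) -> None:
    """Shift every weight in place; a crude stand-in for an optimizer update."""
    for values in params.values():
        for i in range(len(values)):
            values[i] -= lr


# Stand-in for network.state_dict(): names mapped to mutable weight containers.
params = {"layer.weight": [1.0, 2.0]}

aliased = dict(params)            # shallow copy: values are the same live lists
snapshot = copy.deepcopy(params)  # deep copy: independent lists, as in SaveInitWeightsHandler

fake_optimizer_step(params)

print(aliased["layer.weight"])   # [0.5, 1.5]  (followed the update)
print(snapshot["layer.weight"])  # [1.0, 2.0]  (initial weights preserved)
```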
Add noise to predictions before loss
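import torch
from monai.engines import IterationEvents

class AddNoiseHandler:
    def attach(self, engine):
        # FORWARD_COMPLETED fires after the forward pass and before the loss is computed.
        engine.add_event_handler(IterationEvents.FORWARD_COMPLETED, self)

    def __call__(self, engine):
        pred = engine.state.output["pred"]
        # Gaussian noise; the 0.1 scale is an arbitrary example value.
        engine.state.output["pred"] = pred + 0.1 * torch.randn_like(pred)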
Stop training when the loss drops below 0.5
from ignite.engine import Events

class ShutDownHandler:
    def attach(self, engine):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)

    def __call__(self, engine):
        if engine.state.output["loss"] < 0.5:
            engine.terminate()  # ends the run after the current iteration
Accumulate gradients and call optimizer.step() every 10 iterations
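from medl.engines.iterations import Iteration

class AccumulateIteration(Iteration):
    # Number of iterations whose gradients are accumulated before each optimizer step.
    accumulation_steps = 10

    def __call__(self, engine, batchdata):
        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking)
        engine.network.train()
        pred = engine.inferer(batch[0], engine.network)
        loss = engine.loss_function(pred, batch[1])
        # backward() adds into .grad, so gradients build up across iterations;
        # scaling makes the accumulated gradient an average over the window.
        (loss / self.accumulation_steps).backward()
        if engine.state.iteration % self.accumulation_steps == 0:
            engine.optimizer.step()
            # Clear gradients only after stepping; clearing them every iteration
            # would discard all but the last iteration's gradient.
            engine.optimizer.zero_grad()
        return {"image": batch[0], "label": batch[1], "pred": pred, "loss": loss}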
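Because loss.backward() adds into each parameter's .grad, the gradients of several iterations only accumulate if zero_grad() runs after the optimizer step rather than at the start of every iteration, and dividing each loss by the number of steps makes the accumulated gradient equal the average over the micro-batches. The pure-Python check below uses a hypothetical one-parameter model with a hand-written gradient (no PyTorch) to confirm the arithmetic; the function names are illustrative only.

```python
from typing import List, Tuple


def grad(w: float, x: float, y: float) -> float:
    """Derivative with respect to w of the squared error (w * x - y) ** 2."""
    return 2.0 * x * (w * x - y)


def accumulated_grad(w: float, batches: List[Tuple[float, float]], steps: int) -> float:
    """Mimic `steps` calls to (loss / steps).backward() between two zero_grad() calls."""
    total = 0.0  # .grad right after zero_grad()
    for x, y in batches[:steps]:
        total += grad(w, x, y) / steps  # backward() adds into .grad
    return total


def last_only_grad(w: float, batches: List[Tuple[float, float]], steps: int) -> float:
    """Mimic calling zero_grad() at the start of every iteration (no scaling)."""
    x, y = batches[steps - 1]
    return grad(w, x, y)


batches = [(1.0, 2.0), (2.0, 3.0), (3.0, 1.0), (0.5, 0.0)]  # (input, target) micro-batches
w = 1.0
average = sum(grad(w, x, y) for x, y in batches) / len(batches)
print(average, accumulated_grad(w, batches, steps=4), last_only_grad(w, batches, steps=4))
# 1.625 1.625 0.5
```

The equality relies on each micro-batch contributing equally to the average; with unequal micro-batch sizes the scaling factor would need to account for that.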
Customize SupervisedTrainer
from ignite.engine import EventEnum
from monai.engines import SupervisedTrainer

class CustomEvents(EventEnum):
    ACCUMULATE_COMPLETED = "accumulate_completed"

class CustomTrainer(SupervisedTrainer):
    def _register_additional_events(self):
        # Keep MONAI's built-in events, then register the custom ones.
        super()._register_additional_events()
        self.register_events(*CustomEvents)
        # Custom logic can then call self.fire_event(CustomEvents.ACCUMULATE_COMPLETED).
If you need custom training logic that the standard training workflow does not cover, you can bring your own workflow (BYOW).
For custom training logic, the MMAR developer decides what belongs in the JSON config and what belongs in BYOC Python code. In an extreme case, all of the logic can live in a BYOC trainer, which is then set in config_train.json:
{
    "epochs": 1260,
    "learning_rate": 2e-4,
    "amp": true,
    "determinism": {
        "random_seed": 0
    },
    "train": {
        "trainer": {
            "path": "byoc_trainer.Trainer",
            "args": {
                "model_path": "{MMAR_CKPT}",
                "amp": "{amp}",
                "num_epochs": "{epochs}",
                "lr": "{learning_rate}"
            }
        }
    }
}
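The "{epochs}"-style strings in the trainer's "args" refer to top-level settings in the same file. The sketch below illustrates that substitution under two stated assumptions: a value that is exactly "{name}" is replaced by the named top-level value itself, so "{epochs}" becomes the integer 1260 rather than a string, and names not defined in this file, such as MMAR_CKPT, are filled in by a later stage. `resolve` is a hypothetical helper, not Clara Train's implementation.

```python
import re
from typing import Any, Dict

# Matches a string that is exactly one "{name}" placeholder.
PLACEHOLDER = re.compile(r"^\{(\w+)\}$")


def resolve(node: Any, variables: Dict[str, Any]) -> Any:
    """Recursively replace "{name}" strings with variables[name], keeping its type.

    Hypothetical helper: names missing from `variables` (MMAR_CKPT below) are
    left untouched for a later stage that knows the MMAR environment.
    """
    if isinstance(node, dict):
        return {key: resolve(value, variables) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve(item, variables) for item in node]
    if isinstance(node, str):
        match = PLACEHOLDER.match(node)
        if match and match.group(1) in variables:
            return variables[match.group(1)]
    return node


# An abridged copy of the config_train.json shown above.
config = {
    "epochs": 1260,
    "learning_rate": 2e-4,
    "amp": True,
    "train": {
        "trainer": {
            "path": "byoc_trainer.Trainer",
            "args": {
                "model_path": "{MMAR_CKPT}",
                "amp": "{amp}",
                "num_epochs": "{epochs}",
                "lr": "{learning_rate}",
            },
        }
    },
}

# Treat the top-level scalar settings as the variables placeholders may use.
variables = {key: value for key, value in config.items() if not isinstance(value, (dict, list))}
trainer_args = resolve(config, variables)["train"]["trainer"]["args"]
print(trainer_args)
# {'model_path': '{MMAR_CKPT}', 'amp': True, 'num_epochs': 1260, 'lr': 0.0002}
```

Under these assumptions, a BYOC trainer would receive model_path, amp, num_epochs, and lr as keyword arguments of its constructor, so the keys under "args" have to match the parameter names in byoc_trainer.py.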