Bring your own Inference

By default, AIAA uses TritonInference. This means inference requests are sent to Triton, which can leverage multiple GPUs.

You can also write your own inference procedure by extending the base class below; just make sure you put your custom inference module in the <AIAA workspace>/lib folder.

Interface

Below is the Inference base class.

class Inference:
    """Base class for custom AIAA inference procedures.

    Subclasses implement :meth:`inference` (run the model) and :meth:`close`
    (release resources).

    Note: this class does not derive from ``abc.ABC``, so ``@abstractmethod``
    only marks intent here; it does not prevent instantiating the base class.
    """

    @abstractmethod
    def inference(self, name, data, config: ModelConfig, triton_config):
        """Defines an inference procedure.

        Args:
            name: Name of the model being run.
            data: Input data for the request (the CustomInference example
                treats it as a dict keyed by the config's input/output keys).
            config: Model configuration (model path, input/output keys).
            triton_config: Triton-related configuration; may be ignored by
                implementations that do not use Triton.

        Returns:
            Implementation-defined; the CustomInference example returns
            ``data`` updated with the model output.
        """

    @abstractmethod
    def close(self):
        """Closes any resources you might have opened."""

Write a Custom Inference

The following is an example of a custom inference class, CustomInference.

import logging

import torch

from aiaa.configs.modelconfig import ModelConfig
from aiaa.inference.inference import Inference
from aiaa.utils.class_utils import instantiate_class


class CustomInference(Inference):
    """Example inference that runs a PyTorch model in-process instead of via Triton.

    The model is loaded lazily on the first call to :meth:`inference`:

    * if ``network`` is ``None``, the file at ``config.get_path()`` is loaded
      as a TorchScript model with ``torch.jit.load``;
    * otherwise ``network`` is a dict with ``'name'`` and ``'args'`` entries
      that are passed to ``instantiate_class`` to construct the network, and
      the file at ``config.get_path()`` is loaded as its state dict.

    Args:
        is_batched_data: If ``False`` (default), a leading batch dimension of
            size 1 is added to the input before the forward pass and removed
            from the output afterwards. If ``True``, the input is passed to
            the model as-is and the full output is returned.
        network: Optional network description (see above).
        device: Torch device the model and inputs are moved to.
    """

    def __init__(
        self,
        is_batched_data=False,
        network=None,
        device='cuda'
    ):
        self.network = network
        self.device = device
        self.is_batched_data = is_batched_data

        # Loaded lazily by _init_context() on the first inference call.
        self.model = None

    def inference(self, name, data, config: ModelConfig, triton_config):
        """Run the model and store its output in ``data``.

        Args:
            name: Model name; used only for logging here.
            data: Request data. The input is read from ``data[input_key]``
                (or ``data`` itself when the config has no input key), where
                it may be a torch tensor or a numpy array.
            config: Model config providing the model path and the
                inference input/output keys.
            triton_config: Unused by this implementation; accepted for
                interface compatibility.

        Returns:
            ``data``, updated in place with ``data[output_key] = outputs``.
        """
        logger = logging.getLogger(__name__)
        logger.info('Run CustomInference for: %s', name)

        if self.model is None:
            self._init_context(config)

        input_key = config.get_inference_input()
        output_key = config.get_inference_output()
        logger.debug('Input Key: %s; Output Key: %s', input_key, output_key)

        inputs = data[input_key] if input_key else data
        # Pre-transforms may leave a numpy array; the model needs a tensor.
        inputs = inputs if torch.is_tensor(inputs) else torch.from_numpy(inputs)
        # Add a batch dimension of size 1 unless the data is already batched.
        inputs = inputs if self.is_batched_data else inputs[None]
        inputs = inputs.to(self.device)

        logger.info('Input Shape: %s', inputs.shape)
        logger.info('Input Type: %s', type(inputs))

        outputs = self._simple_inference(inputs)
        logger.info('Output Shape: %s', outputs.shape)

        # Only strip the batch dimension we added above; for already-batched
        # input, keep every item of the batch instead of just the first one.
        if not self.is_batched_data:
            outputs = outputs[0]
        data.update({output_key: outputs})
        return data

    def _init_context(self, config: ModelConfig):
        """Load the model from ``config.get_path()`` onto ``self.device`` (once)."""
        logger = logging.getLogger(__name__)

        if self.model is not None:
            return

        path = config.get_path()
        if self.network is None:
            logger.info('Loading TorchScript Model from: %s', path)
            # map_location lets a model saved on one device (e.g. a GPU) be
            # loaded onto self.device, including 'cpu'.
            model = torch.jit.load(path, map_location=self.device)
        else:
            name = self.network['name']
            args = self.network['args']

            logger.info('Loading PyTorch Model Checkpoints from: %s', path)
            logger.info('Constructing PyTorch Model Network %s', name)

            model = instantiate_class(name, args)
            # NOTE: torch.load unpickles the checkpoint file; only load model
            # files from trusted sources.
            model.load_state_dict(torch.load(path, map_location=self.device))

        model.to(self.device)
        model.eval()
        self.model = model

    def _simple_inference(self, inputs):
        """Run a single forward pass with gradient tracking disabled."""
        with torch.no_grad():
            outputs = self.model(inputs)
        return outputs

    def close(self):
        """Release the loaded model; it is reloaded on the next inference call."""
        # Explicit None check: some modules (e.g. containers defining __len__)
        # can be falsy even when loaded.
        if self.model is not None and hasattr(self.model, 'close'):
            self.model.close()
        self.model = None

Use a Custom Inference

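Save this custom inference as custom_inference.py and copy it into the <AIAA workspace>/lib folder. The module name (custom_inference) and class name (CustomInference) together form the "custom_inference.CustomInference" name used in the config below.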

Then you can reference it in the AIAA config, for example:

{
  "version": 1,
  "type": "annotation",
  "labels": [
    "custom_organ"
  ],
  "description": "Custom Model to segment custom organ with user clicks",
  "pre_transforms": [
    {
      "name": "monai.transforms.LoadImaged",
      "args": {
        "keys": "image"
      }
    },
    {
      "name": "aiaa.apps.dextr3d.transforms.PointsToImaged",
      "args": {
        "keys": "points",
        "ref_image": "image"
      }
    },
    {
      "name": "monai.transforms.AddChanneld",
      "args": {
        "keys": [
          "image",
          "points"
        ]
      }
    },
    {
      "name": "monai.transforms.Spacingd",
      "args": {
        "keys": [
          "image",
          "points"
        ],
        "pixdim": [
          1.0,
          1.0,
          1.0
        ]
      }
    },
    {
      "name": "monai.transforms.ScaleIntensityRanged",
      "args": {
        "keys": "image",
        "a_min": -1024,
        "a_max": 1024,
        "b_min": -1.0,
        "b_max": 1.0,
        "clip": true
      }
    },
    {
      "name": "aiaa.apps.dextr3d.transforms.CropForegroundd",
      "args": {
        "keys": [
          "image",
          "points"
        ],
        "source_key": "points",
        "margin": 20
      }
    },
    {
      "name": "monai.transforms.AddExtremePointsChanneld",
      "args": {
        "keys": "image",
        "label_key": "points",
        "sigma": 3,
        "pert": 0
      }
    },
    {
      "name": "aiaa.apps.dextr3d.transforms.Resized",
      "args": {
        "keys": "image",
        "shape": [
          128,
          128,
          128
        ],
        "device": "cuda"
      }
    }
  ],
  "inference": {
    "input": "image",
    "output": "pred",
    "AIAA": {
      "name": "custom_inference.CustomInference",
      "args": {
        "is_batched_data": false
      }
    }
  },
  "post_transforms": [
    {
      "name": "monai.transforms.AddChanneld",
      "args": {
        "keys": "pred"
      }
    },
    {
      "name": "monai.transforms.Activationsd",
      "args": {
        "keys": "pred",
        "softmax": true
      }
    },
    {
      "name": "monai.transforms.AsDiscreted",
      "args": {
        "keys": "pred",
        "argmax": true
      }
    },
    {
      "name": "monai.transforms.SqueezeDimd",
      "args": {
        "keys": "pred",
        "dim": 0
      }
    },
    {
      "name": "monai.transforms.ToNumpyd",
      "args": {
        "keys": "pred"
      }
    },
    {
      "name": "aiaa.apps.dextr3d.transforms.ReverseResized",
      "args": {
        "keys": "pred",
        "ref_shape_key": "image",
        "device": "cuda"
      }
    },
    {
      "name": "aiaa.apps.dextr3d.transforms.RestoreCroppedLabeld",
      "args": {
        "keys": "pred",
        "ref_image": "image"
      }
    }
  ],
  "writer": {
    "name": "aiaa.transforms.Writer",
    "args": {
      "image": "pred",
      "json": "result"
    }
  }
}

Save this config as config_aiaa.json; you can then load the model using the curl command below:

curl -X PUT "http://127.0.0.1:$AIAA_PORT/admin/model/custom_model" \
     -F "config=@config_aiaa.json;type=application/json" \
     -F "data=@[where you store the model]/model.ts"

Note

This custom inference only works with the AIAA backend.