Bring your own Inference

By default, AIAA uses TRTISInference. This means inference requests are sent to Triton, which can spread the work across multiple GPUs.

You can also write your own inference procedure by extending the base class below. Just make sure you put your custom inference code in the <workspace>/lib folder.


We suggest using TRTISInference to better utilize the GPU.


Below is the Inference base class.

class Inference(object):

    def inference(self, name, data, config: ModelConfig, triton_config):
        """Defines an inference procedure."""

    def close(self):
        """Closes any resources you might have opened."""

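The following minimal sketch makes the two-method contract concrete and runs outside AIAA. It defines a stand-in Inference class that mirrors the base class above; the real one is imported from nvmidl.apps.aas.inference.inference. It also defines a hypothetical EchoInference, which uses a plain dict instead of a TransformContext and an identity function instead of a network. The sketch shows only the split between inference(), which may acquire resources on first use, and close(), which releases them. Exactly when AIAA calls close() is not covered here.

```python
import logging


class Inference(object):
    """Stand-in mirroring the AIAA base class shown above (illustration only)."""

    def inference(self, name, data, config, triton_config):
        """Defines an inference procedure."""

    def close(self):
        """Closes any resources you might have opened."""


class EchoInference(Inference):
    """Hypothetical subclass: loads a 'model' lazily and frees it in close()."""

    def __init__(self):
        self._model = None  # created on the first inference() call

    def inference(self, name, data, config, triton_config):
        logging.getLogger(__name__).info("Run EchoInference for: %s", name)
        if self._model is None:
            # Placeholder for loading real weights (a real class would use config).
            self._model = lambda x: x
        # Write the result under a new key, as CustomInference does with 'model'.
        data["model"] = self._model(data["image"])
        return data

    def close(self):
        # Release whatever inference() acquired.
        self._model = None
```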

The following is an example of a custom inference class, CustomInference.

import logging

import torch
import numpy as np
from skimage import util

from ai4med.common.medical_image import MedicalImage
from ai4med.common.transform_ctx import TransformContext
from nvmidl.apps.aas.configs.modelconfig import ModelConfig
from nvmidl.apps.aas.inference.inference import Inference
from nvmidl.apps.aas.inference.inference_utils import InferenceUtils

class CustomInference(Inference):

    def inference(self, name, data, config: ModelConfig, triton_config):
        logger = logging.getLogger(__name__)
        logger.info('Run Custom Inference for: {}'.format(name))

        # Prefer Clara transforms: they hand over a TransformContext holding MedicalImage objects
        assert isinstance(data, TransformContext)

        transform_ctx: TransformContext = data
        img: MedicalImage = transform_ctx.get_image('image')

        shape_fmt = img.get_shape_format()
        logger.info('Shape Format: {}; Current Shape: {}'.format(shape_fmt, img.get_data().shape))

        # Do anything you need here, e.g. locate the uploaded model file
        model_file = config.get_path()
        logger.info('Available Model File (you can do something of your choice): {}'.format(model_file))

        image = img.get_data()[0]
        image_inverted = util.invert(image)
        image_inverted = np.expand_dims(image_inverted, axis=0)  # add batch dim
        logger.info('Image: {}; Inverted: {}'.format(image.shape, image_inverted.shape))

        # Keys become keyword arguments of the network's forward(), i.e. forward(image=...)
        feed_dict = {"image": torch.from_numpy(image_inverted)}

        # Get your own network here:
        # assume your PyTorch model code lives in the <workspace>/lib module named in "path"
        # below, and that the model class is called CustomModel
        network_config = {
            "path": "custom_network.CustomModel",
            "args": {
                "input_channels": 1,
                "output_channels": 1
            }
        }
        network = InferenceUtils.init_external_class(network_config)

        # assume the state_dict was uploaded to AIAA via "-F" and that its location is
        # model_file (returned by config.get_path() above); load it before predicting
        network.load_state_dict(torch.load(model_file, map_location="cpu"))
        network.eval()

        with torch.no_grad():
            outputs = network(**feed_dict)

        outputs = outputs.detach().cpu().numpy()

        # Produce output
        transform_ctx.set_image('model', MedicalImage(outputs[0], shape_fmt))
        return transform_ctx
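The pre-processing in CustomInference only selects a channel, inverts the volume, and adds a leading axis. The NumPy-only sketch below reproduces those steps so you can check the shapes without AIAA or scikit-image installed. It assumes a float image with a leading channel axis. It relies on scikit-image's documented behavior that util.invert on a float image, with the default signed_float=False, returns 1 - image. The helper name prepare_input is hypothetical.

```python
import numpy as np


def prepare_input(channels_first: np.ndarray) -> np.ndarray:
    """Mimic CustomInference's pre-processing using NumPy only.

    Assumes a float image with a leading channel axis, e.g. (C, H, W, D).
    """
    image = channels_first[0]                # keep the first channel -> (H, W, D)
    inverted = 1.0 - image                   # what util.invert does for float input
    return np.expand_dims(inverted, axis=0)  # add a leading axis -> (1, H, W, D)


vol = np.zeros((1, 4, 5, 6), dtype=np.float32)
x = prepare_input(vol)
print(x.shape)  # (1, 4, 5, 6)
```

Only the first channel survives the [0] index, so multi-channel inputs would need different handling. Whether the leading axis of the result acts as a batch axis or a channel axis depends on how your CustomModel is written.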

Let’s assume we save the above code as the custom_inference module (the name the config below refers to) and put it, together with the module that defines CustomModel, in the <workspace>/lib folder. The example config_aiaa.json is provided below:

{
  "version": "3",
  "type": "others",
  "labels": [ ... ],
  "description": "Custom Model to solve some xyz use case",
  "pre_transforms": [
    {
      "name": "LoadNifti",
      "args": {
        "fields": "image",
        "as_closest_canonical": "false"
      }
    },
    {
      "name": "ConvertToChannelsFirst",
      "args": {
        "fields": "image"
      }
    }
  ],
  "inference": {
    "name": "custom_inference.CustomInference",
    "args": {}
  },
  "post_transforms": [
    {
      "name": "CopyProperties",
      "args": {
        "fields": ["model"],
        "from_field": "image",
        "properties": [ ... ]
      }
    }
  ],
  "writer": {
    "name": "WriteNifti",
    "args": {
      "field": "model"
    }
  }
}

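The "name" of the inference entry ("custom_inference.CustomInference") and the "path" in network_config ("custom_network.CustomModel") are both dotted module.ClassName strings. This page does not describe how AIAA and InferenceUtils.init_external_class resolve them internally. The sketch below shows the common importlib idiom for the same idea, using hypothetical helper names. It assumes the module is importable; for your own code, that would mean the <workspace>/lib folder is on sys.path.

```python
import importlib


def load_class(dotted_path: str):
    """Resolve a "package.module.ClassName" string to the class object."""
    module_name, _, class_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected 'module.ClassName', got {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def init_from_config(spec: dict):
    """Instantiate a {"path" or "name": ..., "args": {...}} entry like the ones above."""
    cls = load_class(spec.get("path") or spec["name"])
    return cls(**spec.get("args", {}))
```

For example, init_from_config({"path": "fractions.Fraction", "args": {"numerator": 3, "denominator": 4}}) builds Fraction(3, 4). The "args" object plays the same role as the "args" blocks in config_aiaa.json and network_config.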
Once the custom code is in the right place, we can load this model with the curl command below:

curl -X PUT "$LOCAL_PORT/admin/model/custom_inference?native=true" \
     -F "config=@config_aiaa.json;type=application/json" \
     -F "data=@[where you store model]/"


Note that this example serves the model with native PyTorch instead of Triton. If AIAA is running with the Triton backend, you therefore need to add the flag native=true when uploading the model. Alternatively, start AIAA with --engine AIAA.
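If you prefer to script the upload, the sketch below builds a request equivalent to the curl command above using only the Python standard library. It sends the same two multipart form fields: config, as application/json, and data. It appends ?native=true when native is set. Several details are assumptions: the function name, the placeholder address in the usage note, and the application/octet-stream type for the model part (the type curl typically uses for such files). The request is only built, not sent.

```python
import uuid
import urllib.request
from pathlib import Path


def build_upload_request(base_url, model_name, config_path, model_path, native=True):
    """Build (but do not send) a PUT request mirroring the curl upload above."""
    boundary = uuid.uuid4().hex
    body = b""
    # Same form fields as curl's -F options: "config" (JSON) and "data" (model file).
    for field, path, ctype in (
        ("config", Path(config_path), "application/json"),
        ("data", Path(model_path), "application/octet-stream"),
    ):
        header = (
            f"--{boundary}\r\n"
            f'Content-Disposition: form-data; name="{field}"; filename="{path.name}"\r\n'
            f"Content-Type: {ctype}\r\n\r\n"
        )
        body += header.encode() + path.read_bytes() + b"\r\n"
    body += f"--{boundary}--\r\n".encode()

    # native=true asks AIAA to serve the model natively instead of through Triton.
    query = "?native=true" if native else ""
    return urllib.request.Request(
        f"{base_url}/admin/model/{model_name}{query}",
        data=body,
        method="PUT",
        headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
    )
```

To send it, call urllib.request.urlopen(build_upload_request("http://127.0.0.1:5000", "custom_inference", "config_aiaa.json", "<your model file>")). The address here is only a placeholder for wherever your AIAA server listens.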