PyTorch Support in AIAA

There are two ways to use PyTorch-trained models in AIAA.

  1. Utilize the Triton inference server as backend.

  2. Use PyTorch directly in AIAA.

Triton Inference Server

To use a model in AIAA, you need to provide both the model and a matching config.

The model needs to be in TorchScript format. Please refer to Convert PyTorch trained network.
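A trained PyTorch model can be exported to TorchScript with torch.jit.trace (or torch.jit.script). The sketch below uses a tiny placeholder network and the [3, 256, 256] input shape from the example config further down; your actual network and weights come from your own training run.

```python
import torch
import torch.nn as nn

# Placeholder network standing in for a trained segmentation model.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))

model = TinyNet().eval()

# Trace with a dummy input matching the config dims [3, 256, 256]
# (plus the leading batch dimension).
example = torch.randn(1, 3, 256, 256)
traced = torch.jit.trace(model, example)

# Save the TorchScript file that will be uploaded to AIAA.
traced.save("model.pt")

# Sanity check: the reloaded TorchScript model matches the original.
reloaded = torch.jit.load("model.pt")
```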

Some guidelines on preparing the config for AIAA:

  • The config needs to have a “trtis” section, and its “platform” needs to be “pytorch_libtorch”.

  • Model inputs are named “INPUT__x”, where x starts from 0.

  • Model outputs are named “OUTPUT__x”, where x starts from 0.

  • The node mapping specifies which keys in the transforms correspond to which inputs/outputs of the network.

For example:

{
  "inference": {
    "name": "TRTISInference",
    "node_mapping": {
      "INPUT__0": "image",
      "OUTPUT__0": "model"
    },
    "trtis": {
      "platform": "pytorch_libtorch",
      "input": [
        {
          "name": "INPUT__0",
          "data_type": "TYPE_FP32",
          "dims": [3, 256, 256]
        },
        {
          "name": "INPUT__1",
          "data_type": "TYPE_FP32",
          "dims": [3, 256, 256]
        }
      ],
      "output": [
        {
          "name": "OUTPUT__0",
          "data_type": "TYPE_FP32",
          "dims": [1, 256, 256]
        }
      ]
    }
  }
}
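A common mistake is a node_mapping key that does not match any declared Triton input or output name. As a quick sanity check (a minimal sketch, not part of AIAA), you can verify the example config above programmatically:

```python
import json

# The example AIAA config from above, embedded as a string for this sketch.
config = json.loads("""
{
  "inference": {
    "name": "TRTISInference",
    "node_mapping": {"INPUT__0": "image", "OUTPUT__0": "model"},
    "trtis": {
      "platform": "pytorch_libtorch",
      "input": [
        {"name": "INPUT__0", "data_type": "TYPE_FP32", "dims": [3, 256, 256]},
        {"name": "INPUT__1", "data_type": "TYPE_FP32", "dims": [3, 256, 256]}
      ],
      "output": [
        {"name": "OUTPUT__0", "data_type": "TYPE_FP32", "dims": [1, 256, 256]}
      ]
    }
  }
}
""")

inference = config["inference"]
trtis = inference["trtis"]

# Every key in node_mapping must be a declared input or output name.
declared = {t["name"] for t in trtis["input"]} | {t["name"] for t in trtis["output"]}
for name in inference["node_mapping"]:
    assert name in declared, f"node_mapping key {name} not declared in trtis section"
```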

Once you have model.pt (in TorchScript format), you can load the model into AIAA as follows:

curl -X PUT "http://127.0.0.1/admin/model/segmentation_2d_brain" \
     -F "config=@config_aiaa.json;type=application/json" \
     -F "data=@model.pt"

Note

Please refer to Triton documentation for more details.

Use PyTorch directly

To run specific inferences using native PyTorch, you can write your own inference and specify the flag native=true when uploading the inference. The advantage is that you don’t have to convert your PyTorch saved model to the TorchScript format.

For example:

curl -X PUT "http://127.0.0.1/admin/model/custom_model?native=true" \
     -F "config=@config_aiaa.json;type=application/json" \
     -F "data=@[where you store model]/state_dict.pt"
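The native route skips TorchScript entirely: because the Python class of the network is available at inference time, the custom inference code can restore the model directly from the saved state_dict. A minimal sketch of that loading step (the network class and file name are placeholders; here the state_dict is created on the fly, whereas normally it comes from training):

```python
import torch
import torch.nn as nn

# Placeholder for your actual network definition.
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))

# For this sketch only: produce a state_dict file to load from.
torch.save(MyNet().state_dict(), "state_dict.pt")

# The loading step a native inference would perform: no TorchScript needed.
model = MyNet()
model.load_state_dict(torch.load("state_dict.pt", map_location="cpu"))
model.eval()

with torch.no_grad():
    pred = model(torch.randn(1, 3, 256, 256))
```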

Please refer to Bring your own Inference for a complete example.