Passing Data to Triton Inference Server
This document provides an overview of how to pass data to a Triton Inference Server model, in this case a Python backend model named “prediction_and_shapley”. According to the config.pbtxt file generated during training, the model requires node features, edge connectivity, and node feature masks for node attributes. If edge attributes were used during training, it also requires edge features and edge feature masks.
Model Configuration
Inputs
NODE_FEATURES (one per node type)
Data Type: FP32
Shape: [ -1, <NUM_INPUT_FEATURE> ]
Interpretation: One row per node. The number of nodes (first dimension) is dynamic, and each row holds <NUM_INPUT_FEATURE> floating-point feature values.
EDGE_INDICES (one per edge type)
Data Type: INT64
Shape: [ 2, -1 ]
Interpretation: The edge connectivity of the graph. The first dimension is fixed at 2 rows (source and destination node indices), while the second dimension is dynamic (the number of edges).
EDGE_FEATURES (one per edge type that has edge attributes)
Data Type: FP32
Shape: [ -1, <NUM_INPUT_FEATURE> ]
Interpretation: One row per edge. The number of edges is dynamic, and each row holds <NUM_INPUT_FEATURE> floating-point feature values for that edge type.
COMPUTE_SHAP
Data Type: BOOL
Shape: [ 1 ]
Interpretation: A single boolean flag (wrapped in a one-element array) that indicates whether SHAP values should be computed.
FEATURE_MASK (one per node type and per edge type that has edge attributes)
Data Type: INT32
Shape: [ <NUM_INPUT_FEATURE> ]
Interpretation: The FEATURE_MASK defines how individual node or edge feature columns are grouped for operations such as SHAP (Shapley) value computation. Each element is an integer group ID ranging from 0 to number_of_features - 1, where number_of_features is the number of raw (logical) attributes before encoding. Features that belong to the same logical group must share the same mask value. This ensures that correlated or multi-dimensional representations of a single attribute are treated as one unit during SHAP computation.
Example: If a single categorical variable (e.g., customer_id) is represented as a 10-dimensional embedding, all 10 dimensions corresponding to that variable should be assigned the same mask value in the corresponding FEATURE_MASK array.
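As a sketch of how such a mask can be built, assume (hypothetically) that the 24 merchant feature columns were encoded from three raw attributes laid out in this order: a scalar `amount` (1 column), a one-hot `category` (13 columns), and an embedded `merchant_id` (10 columns). Repeating each attribute's group ID once per encoded column yields a mask in which all columns of one attribute share a value. The attribute names, widths, and the `build_feature_mask` helper are illustrative assumptions, not part of the model.

```python
import numpy as np


def build_feature_mask(encoded_widths):
    """Return an INT32 mask that assigns group ID i to every column of raw attribute i.

    encoded_widths: number of encoded columns per raw attribute, in the same
    order as the columns appear in the feature tensor (hypothetical layout).
    """
    group_ids = np.arange(len(encoded_widths), dtype=np.int32)
    # np.repeat keeps the int32 dtype; astype makes the dtype explicit
    return np.repeat(group_ids, encoded_widths).astype(np.int32)


# Hypothetical layout: amount (1 col), category one-hot (13 cols), merchant_id embedding (10 cols)
feature_mask_merchant = build_feature_mask([1, 13, 10])
print(feature_mask_merchant.shape)           # (24,)
print(np.unique(feature_mask_merchant))      # [0 1 2]
```

Because the group IDs come from `np.arange`, they are contiguous from 0 to the number of raw attributes minus 1, matching the range described above.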
Outputs
PREDICTION
Data Type: FP32
Shape: [ -1, 1 ]
Interpretation: The model’s predictions, one row per target edge.
SHAP_VALUES (one per node type and per edge type that has edge attributes)
Data Type: FP32
Shape: [ -1, D ]
Interpretation: SHAP values for node and edge attributes, where D is the number of feature columns of the corresponding input (for example, 24 for x_merchant in the sample configuration).
Let’s take a closer look at the I/O section of a sample configuration protobuf:
```protobuf
name: "prediction_and_shapley"
backend: "python"
input [
  {
    name: "x_merchant"
    data_type: TYPE_FP32
    dims: [ -1, 24 ]
  },
  {
    name: "x_user"
    data_type: TYPE_FP32
    dims: [ -1, 13 ]
  },
  {
    name: "COMPUTE_SHAP"
    data_type: TYPE_BOOL
    dims: [ 1 ]
  },
  {
    name: "feature_mask_merchant"
    data_type: TYPE_INT32
    dims: [ 24 ]
  },
  {
    name: "feature_mask_user"
    data_type: TYPE_INT32
    dims: [ 13 ]
  },
  {
    name: "edge_index_user_to_merchant"
    data_type: TYPE_INT64
    dims: [ 2, -1 ]
  },
  {
    name: "edge_attr_user_to_merchant"
    data_type: TYPE_FP32
    dims: [ -1, 38 ]
  },
  {
    name: "edge_feature_mask_user_to_merchant"
    data_type: TYPE_INT32
    dims: [ 38 ]
  }
]
output [
  {
    name: "PREDICTION"
    data_type: TYPE_FP32
    dims: [ -1, 1 ]
  },
  {
    name: "shap_values_merchant"
    data_type: TYPE_FP32
    dims: [ -1, 24 ]
  },
  {
    name: "shap_values_user"
    data_type: TYPE_FP32
    dims: [ -1, 13 ]
  },
  {
    name: "shap_values_user_to_merchant"
    data_type: TYPE_FP32
    dims: [ -1, 38 ]
  }
]
```
According to the I/O specification shown above, the prediction_and_shapley model expects floating-point tensors for user and merchant nodes and for user_to_merchant edges, together with edge connectivity, one feature mask per node and edge type, and a control input (COMPUTE_SHAP) that specifies whether SHAP values should be computed.
The -1 dimensions indicate that the number of users, merchants, and edges can vary from request to request.
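| Input Name | Data Type | Shape | Description |
|---|---|---|---|
| x_merchant | TYPE_FP32 | [ -1, 24 ] | Floating-point tensor of merchant features. Each merchant node has 24 features. The number of merchant nodes can vary. |
| x_user | TYPE_FP32 | [ -1, 13 ] | Floating-point tensor of user features. Each user node has 13 features. |
| COMPUTE_SHAP | TYPE_BOOL | [ 1 ] | A boolean flag that controls whether to compute SHAP values (True or False). |
| feature_mask_merchant | TYPE_INT32 | [ 24 ] | Integer mask that groups the 24 merchant feature columns into raw attributes for SHAP computation (one group ID per column). |
| feature_mask_user | TYPE_INT32 | [ 13 ] | Integer mask that groups the 13 user feature columns into raw attributes for SHAP computation. |
| edge_index_user_to_merchant | TYPE_INT64 | [ 2, -1 ] | Edge connectivity defining user → merchant edges. The two rows hold source (user) and destination (merchant) node indices. The number of edges can vary. |
| edge_attr_user_to_merchant | TYPE_FP32 | [ -1, 38 ] | Floating-point tensor of edge features (user–merchant relationships). Each edge has 38 features. |
| edge_feature_mask_user_to_merchant | TYPE_INT32 | [ 38 ] | Integer mask that groups the 38 edge feature columns into raw attributes for SHAP computation. |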
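The client example in this document fills row 0 of the edge index with values drawn from range(num_users) and row 1 with values from range(num_merchants), which suggests each entry is a row position in x_user or x_merchant rather than a raw ID. A minimal sketch under that assumption: the raw IDs, the transaction pairs, and the `build_edge_index` helper are hypothetical, and it maps each ID to its row position before stacking the result into a [2, num_edges] INT64 tensor.

```python
import numpy as np


def build_edge_index(edges, user_ids, merchant_ids):
    """Convert (user_id, merchant_id) pairs into a [2, num_edges] INT64 edge index.

    Assumes row i of x_user corresponds to user_ids[i] and row j of x_merchant
    corresponds to merchant_ids[j] (hypothetical ID lists).
    """
    user_pos = {uid: i for i, uid in enumerate(user_ids)}
    merchant_pos = {mid: j for j, mid in enumerate(merchant_ids)}
    src = [user_pos[u] for u, _ in edges]      # row 0: source (user) row positions
    dst = [merchant_pos[m] for _, m in edges]  # row 1: destination (merchant) row positions
    # reshape keeps the [2, 0] shape valid even when there are no edges
    return np.array([src, dst], dtype=np.int64).reshape(2, len(edges))


# Hypothetical raw IDs in the same order as the rows of x_user / x_merchant
user_ids = ["u42", "u7", "u19"]
merchant_ids = ["m3", "m11"]
transactions = [("u7", "m11"), ("u42", "m3"), ("u19", "m11")]

edge_index_user_to_merchant = build_edge_index(transactions, user_ids, merchant_ids)
print(edge_index_user_to_merchant)
# [[1 0 2]
#  [1 0 1]]
```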
The model produces one main predictive output and three explainability outputs (SHAP values).
Each SHAP output corresponds to a part of the model input: merchant, user, and edge features.
These outputs together provide both predictions and feature-level explanations showing how each feature contributed to the result.
| Output Name | Data Type | Shape | Description |
|---|---|---|---|
| PREDICTION | TYPE_FP32 | [ -1, 1 ] | The model’s primary output: a predicted score or probability. The batch size is variable (-1), with one row per target edge. |
| shap_values_merchant | TYPE_FP32 | [ -1, 24 ] | SHAP (Shapley) feature attributions for merchant features. Each merchant node has up to 24 SHAP values explaining the contribution of each raw attribute to the prediction. |
| shap_values_user | TYPE_FP32 | [ -1, 13 ] | SHAP feature attributions for user features. Each user node has up to 13 SHAP values explaining the contribution of each raw attribute to the prediction. |
| shap_values_user_to_merchant | TYPE_FP32 | [ -1, 38 ] | SHAP feature attributions for edge (interaction) features between users and merchants. Each edge has up to 38 SHAP values indicating how raw relationship attributes affected the prediction. |
PREDICTION: Represents the model’s primary output, a probability for the prediction task.
shap_values_* tensors: Represent feature-level explanations and are computed only if the input flag COMPUTE_SHAP = True.
SHAP outputs correspond directly to specific input tensors:
x_merchant → shap_values_merchant
x_user → shap_values_user
edge_attr_user_to_merchant → shap_values_user_to_merchant
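The mapping above follows a fixed naming rule: strip the x_ or edge_attr_ prefix from a feature input and prepend shap_values_. A client can therefore derive the outputs to request from the input names alone. A small sketch of that rule, inferred from the sample configuration (the `shap_output_name` helper is hypothetical):

```python
def shap_output_name(input_name):
    """Map a feature input name to its SHAP output name, or None for other inputs."""
    for prefix in ("x_", "edge_attr_"):
        if input_name.startswith(prefix):
            return "shap_values_" + input_name[len(prefix):]
    return None  # masks, edge indices, and COMPUTE_SHAP have no SHAP output


# Input names from the sample config.pbtxt
input_names = [
    "x_merchant",
    "x_user",
    "COMPUTE_SHAP",
    "feature_mask_merchant",
    "feature_mask_user",
    "edge_index_user_to_merchant",
    "edge_attr_user_to_merchant",
    "edge_feature_mask_user_to_merchant",
]
outputs = ["PREDICTION"] + [o for o in map(shap_output_name, input_names) if o]
print(outputs)
# ['PREDICTION', 'shap_values_merchant', 'shap_values_user', 'shap_values_user_to_merchant']
```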
Additional Model Parameters
The model configuration includes parameters such as "in_channels", "hidden_channels", "out_channels", "n_hops", and file paths like "embedding_generator_model_state_dict" and "embeddings_based_xgboost_model". These parameters configure model internals. While they do not affect how data is passed at inference time, they determine how the backend processes the inputs.
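For reference, Triton's Python backend passes the model configuration to a model's `initialize(self, args)` method as a JSON string in `args["model_config"]`, and each entry of the config's `parameters` block appears there as `{"<key>": {"string_value": "<value>"}}`. The sketch below pulls those entries into a plain dict. The parameter values are hypothetical, the `read_string_parameters` helper is illustrative, and how the prediction_and_shapley backend actually consumes these values is not documented here.

```python
import json


def read_string_parameters(model_config_json):
    """Return {key: value} for every entry in the model config's parameters block."""
    config = json.loads(model_config_json)
    params = config.get("parameters", {})
    return {key: entry["string_value"] for key, entry in params.items()}


# Hypothetical JSON, shaped like args["model_config"] inside a Python backend model
example_json = json.dumps(
    {
        "name": "prediction_and_shapley",
        "backend": "python",
        "parameters": {
            "n_hops": {"string_value": "2"},
            "hidden_channels": {"string_value": "32"},
        },
    }
)

params = read_string_parameters(example_json)
n_hops = int(params["n_hops"])  # parameter values arrive as strings and must be cast
```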
Preparing and Passing Data
When passing data to this model, ensure that each input is a NumPy array (or a similar structure) with the correct shape and data type.
Data Types and Shape
Ensure that each array matches the data type and dimensions specified in the model configuration file: FP32 for node and edge features, INT64 for edge indices, INT32 for feature masks, and BOOL for the COMPUTE_SHAP flag.
Dynamic dimensions (denoted by -1) allow any number of nodes or edges, as long as the fixed dimensions are preserved: <NUM_INPUT_FEATURE> for node features, edge features, and feature masks, and 2 for the edge index rows.
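A quick client-side check can catch shape or dtype mismatches before a request reaches the server. The sketch below hard-codes the sample configuration shown earlier and treats -1 as a wildcard dimension. The `EXPECTED` table and `check_inputs` helper are illustrative, not part of Triton; for another model, you would copy its own input names, types, and dims.

```python
import numpy as np

# Expected dtype and shape per input, copied from the sample config.pbtxt; -1 = any size
EXPECTED = {
    "x_merchant": (np.float32, (-1, 24)),
    "x_user": (np.float32, (-1, 13)),
    "COMPUTE_SHAP": (np.bool_, (1,)),
    "feature_mask_merchant": (np.int32, (24,)),
    "feature_mask_user": (np.int32, (13,)),
    "edge_index_user_to_merchant": (np.int64, (2, -1)),
    "edge_attr_user_to_merchant": (np.float32, (-1, 38)),
    "edge_feature_mask_user_to_merchant": (np.int32, (38,)),
}


def check_inputs(data):
    """Return a list of problems; an empty list means every input matches EXPECTED."""
    problems = []
    for name, (dtype, dims) in EXPECTED.items():
        if name not in data:
            problems.append(f"{name}: missing")
            continue
        arr = np.asarray(data[name])
        if arr.dtype != dtype:
            problems.append(f"{name}: dtype {arr.dtype}, expected {np.dtype(dtype)}")
        # Same rank, and every fixed (non -1) dimension must match exactly
        shape_ok = arr.ndim == len(dims) and all(
            d == -1 or d == s for d, s in zip(dims, arr.shape)
        )
        if not shape_ok:
            problems.append(f"{name}: shape {arr.shape}, expected {dims}")
    return problems


# Usage: zero-filled arrays with 3 rows/edges wherever the dimension is dynamic
good = {
    name: np.zeros([3 if d == -1 else d for d in dims], dtype=dtype)
    for name, (dtype, dims) in EXPECTED.items()
}
print(check_inputs(good))  # []

bad = dict(good, x_user=np.zeros((3, 12), dtype=np.float64))
print(check_inputs(bad))   # reports both the dtype and the shape of x_user
```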
Using the Triton Client Libraries
Prepare batch data the same way you prepared your training data. For demonstration purposes, the following code snippet uses random data. In real requests, build each feature mask so that all columns produced by encoding the same raw attribute share the same mask value.
Example with HTTP Client (Python)
The following example uses the Triton Python HTTP client to create and send an inference request with randomly generated data.
```python
import numpy as np
import tritonclient.http as httpclient
from tritonclient.http import InferInput, InferRequestedOutput

HOST = "localhost"  # Host IP of the Triton server; replace with your own
HTTP_PORT = 8000  # HTTP port; 8000 is Triton's default


def make_example_data(
    num_merchants=5,
    num_users=7,
    num_edges=3,
    merchant_feature_dim=24,
    user_feature_dim=13,
    user_to_merchant_feature_dim=38,
    nr_user_raw_attributes=2,
    nr_merchant_raw_attributes=4,
    nr_user_to_merchant_raw_attributes=6,
):
    # Node features
    x_merchant = np.random.randn(num_merchants, merchant_feature_dim).astype(np.float32)
    x_user = np.random.randn(num_users, user_feature_dim).astype(np.float32)

    # SHAP flag and node feature masks (random group IDs, for demonstration only)
    compute_shap = np.array([True], dtype=np.bool_)
    feature_mask_merchant = np.random.randint(
        0, nr_merchant_raw_attributes, size=(merchant_feature_dim,), dtype=np.int32
    )
    feature_mask_user = np.random.randint(
        0, nr_user_raw_attributes, size=(user_feature_dim,), dtype=np.int32
    )

    # Edge index [2, num_edges]: row 0 = user (source), row 1 = merchant (destination)
    edge_index_user_to_merchant = np.vstack(
        [
            np.random.randint(0, num_users, size=(num_edges,)),
            np.random.randint(0, num_merchants, size=(num_edges,)),
        ]
    ).astype(np.int64)

    # Edge features [num_edges, user_to_merchant_feature_dim]
    edge_attr_user_to_merchant = np.random.randn(
        num_edges, user_to_merchant_feature_dim
    ).astype(np.float32)

    # Edge feature mask
    feature_mask_user_to_merchant = np.random.randint(
        0,
        nr_user_to_merchant_raw_attributes,
        size=(user_to_merchant_feature_dim,),
        dtype=np.int32,
    )

    return {
        "x_merchant": x_merchant,
        "x_user": x_user,
        "COMPUTE_SHAP": compute_shap,
        "feature_mask_merchant": feature_mask_merchant,
        "feature_mask_user": feature_mask_user,
        "edge_index_user_to_merchant": edge_index_user_to_merchant,
        "edge_attr_user_to_merchant": edge_attr_user_to_merchant,
        "edge_feature_mask_user_to_merchant": feature_mask_user_to_merchant,
    }


def prepare_and_send_inference_request(data):
    # Connect to Triton
    client = httpclient.InferenceServerClient(url=f"{HOST}:{HTTP_PORT}")

    # Prepare inputs
    inputs = []

    def _add_input(name, arr, dtype):
        inp = InferInput(name, list(arr.shape), datatype=dtype)
        inp.set_data_from_numpy(arr)
        inputs.append(inp)

    for key, value in data.items():
        if key.startswith("x_"):
            dtype = "FP32"
        elif key.startswith("feature_mask_"):
            dtype = "INT32"
        elif key.startswith("edge_feature_mask_"):
            dtype = "INT32"
        elif key.startswith("edge_index_"):
            dtype = "INT64"
        elif key.startswith("edge_attr_"):
            dtype = "FP32"
        elif key == "COMPUTE_SHAP":
            dtype = "BOOL"
        else:
            continue  # skip keys that are not model inputs
        _add_input(key, value, dtype)

    # Request PREDICTION plus one SHAP output per node/edge feature input
    outputs = [InferRequestedOutput("PREDICTION")]
    for key in data:
        if key.startswith("x_"):
            node = key[len("x_") :]  # extract node name
            outputs.append(InferRequestedOutput(f"shap_values_{node}"))
        elif key.startswith("edge_attr_"):
            edge_name = key[len("edge_attr_") :]  # extract edge name
            outputs.append(InferRequestedOutput(f"shap_values_{edge_name}"))

    # Send request
    model_name = "prediction_and_shapley"
    response = client.infer(
        model_name, inputs=inputs, request_id=str(1), outputs=outputs, timeout=3000
    )

    result = {}
    # Always include the prediction
    result["PREDICTION"] = response.as_numpy("PREDICTION")
    # Add SHAP values
    for key in data:
        if key.startswith("x_"):
            node = key[len("x_") :]  # e.g. "merchant", "user"
            result[f"shap_values_{node}"] = response.as_numpy(f"shap_values_{node}")
        if key.startswith("edge_attr_"):
            edge_name = key[len("edge_attr_") :]  # e.g. "user_to_merchant"
            result[f"shap_values_{edge_name}"] = response.as_numpy(
                f"shap_values_{edge_name}"
            )
    return result


test_data = make_example_data()
compute_shap = True
# The dict union operator (|) requires Python 3.9+
result = prepare_and_send_inference_request(
    test_data | {"COMPUTE_SHAP": np.array([compute_shap], dtype=np.bool_)}
)
```
Note that the Triton Python client libraries (both HTTP and gRPC) simplify the process by handling serialization, but you can also send raw JSON payloads to the REST API if needed.
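A minimal sketch of such a raw request, assuming Triton's HTTP/REST endpoint follows the KServe v2 inference protocol: a POST to /v2/models/<model>/infer whose body lists each input's name, shape, datatype, and flattened row-major data. The helper names, host, and port are assumptions. Only payload construction runs without a server; `send_request` needs a running Triton instance. The example payload is deliberately partial, just to show the structure; a real request must include every input from the configuration.

```python
import json
import urllib.request

import numpy as np

# NumPy dtype -> KServe v2 datatype string
_V2_DATATYPES = {
    np.dtype(np.float32): "FP32",
    np.dtype(np.int32): "INT32",
    np.dtype(np.int64): "INT64",
    np.dtype(np.bool_): "BOOL",
}


def build_v2_payload(data, output_names):
    """Build a KServe v2 JSON inference body from a {name: ndarray} dict."""
    inputs = [
        {
            "name": name,
            "shape": list(arr.shape),
            "datatype": _V2_DATATYPES[arr.dtype],
            "data": arr.flatten().tolist(),  # row-major, flattened to plain Python values
        }
        for name, arr in data.items()
    ]
    return {"inputs": inputs, "outputs": [{"name": n} for n in output_names]}


def send_request(payload, host="localhost", port=8000, model="prediction_and_shapley"):
    """POST the payload to Triton's REST endpoint and return {output name: ndarray}."""
    req = urllib.request.Request(
        f"http://{host}:{port}/v2/models/{model}/infer",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return {
        out["name"]: np.array(out["data"]).reshape(out["shape"])
        for out in body["outputs"]
    }


# Partial, hypothetical payload that only illustrates the JSON structure
payload = build_v2_payload(
    {
        "x_user": np.ones((2, 13), dtype=np.float32),
        "COMPUTE_SHAP": np.array([False]),
    },
    ["PREDICTION"],
)
```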