Class InferenceMemoryNLP
Defined in File inference_memory_nlp.hpp
Base Type
public morpheus::InferenceMemory
(Class InferenceMemory)
-
class InferenceMemoryNLP : public morpheus::InferenceMemory
This is a container class for data that needs to be submitted to the inference server for NLP use cases.
Public Functions
-
InferenceMemoryNLP(TensorIndex count, TensorObject &&input_ids, TensorObject &&input_mask, TensorObject &&seq_ids)
Construct a new Inference Memory NLP object.
- Parameters
count – Number of messages
input_ids – The token-ids for each string, padded with 0s to max_length
input_mask – The mask for the token-ids, where set positions identify valid token-id values
seq_ids – Ids used to index from an inference input to a message. Necessary since there can be more inference inputs than messages (for example, when some messages are broken into multiple inference requests)
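The relationship between the three tensors can be sketched with plain vectors. This is a minimal illustration, not the Morpheus implementation: `Row`, `PaddedRow`, `pad_and_mask`, and `rows_per_message` are hypothetical names, and the sketch assumes that the first column of each `seq_ids` row holds the index of the message that row came from.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one tensor row; the real class stores TensorObject values.
using Row = std::vector<std::int32_t>;

struct PaddedRow
{
    Row input_ids;   // token ids, padded with 0s up to max_length
    Row input_mask;  // 1 where input_ids holds a real token, 0 where it is padding
};

// Pad (or truncate) one token sequence to max_length and build the matching mask.
PaddedRow pad_and_mask(const Row& tokens, std::size_t max_length)
{
    PaddedRow out{Row(max_length, 0), Row(max_length, 0)};
    for (std::size_t i = 0; i < tokens.size() && i < max_length; ++i)
    {
        out.input_ids[i]  = tokens[i];
        out.input_mask[i] = 1;
    }
    // Both rows always have the same width.
    assert(out.input_ids.size() == out.input_mask.size());
    return out;
}

// Count how many inference rows belong to each message.
// Assumption: seq_ids[i][0] is the index of the message that row i came from.
std::vector<std::size_t> rows_per_message(const std::vector<Row>& seq_ids, std::size_t message_count)
{
    std::vector<std::size_t> counts(message_count, 0);
    for (const Row& row : seq_ids)
    {
        if (row.empty() || row[0] < 0)
        {
            continue;  // skip malformed rows
        }
        const auto msg = static_cast<std::size_t>(row[0]);
        if (msg < message_count)
        {
            ++counts[msg];
        }
    }
    return counts;
}
```

Under that assumption, a message split into two inference requests shows up as two `seq_ids` rows with the same first value, which is how results can be folded back onto the original message.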
-
const TensorObject &get_input_ids() const
Get the input ids object.
- Throws
std::runtime_error – If no tensor named “input_ids” exists
- Returns
const TensorObject&
-
const TensorObject &get_input_mask() const
Get the input mask object.
- Throws
std::runtime_error – If no tensor named “input_mask” exists
- Returns
const TensorObject&
-
const TensorObject &get_seq_ids() const
Get the seq ids object.
- Throws
std::runtime_error – If no tensor named “seq_ids” exists
- Returns
const TensorObject&
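The getters share one contract: return a const reference when the named tensor exists, and throw std::runtime_error otherwise. The sketch below shows that contract with a hypothetical `NamedTensorStore` and a `FakeTensor` alias; neither is part of Morpheus, and the real class may store and look up its tensors differently.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a tensor; not the Morpheus TensorObject.
using FakeTensor = std::vector<int>;

class NamedTensorStore
{
  public:
    void put(const std::string& name, FakeTensor tensor)
    {
        m_tensors[name] = std::move(tensor);
    }

    // Return the tensor by const reference, or throw if the name is unknown.
    const FakeTensor& get(const std::string& name) const
    {
        auto found = m_tensors.find(name);
        if (found == m_tensors.end())
        {
            throw std::runtime_error("No tensor named \"" + name + "\" exists");
        }
        return found->second;
    }

  private:
    std::map<std::string, FakeTensor> m_tensors;
};

// Caller-side pattern: check for a tensor without letting the exception escape.
bool has_tensor(const NamedTensorStore& store, const std::string& name)
{
    try
    {
        store.get(name);
        return true;
    }
    catch (const std::runtime_error&)
    {
        return false;
    }
}
```

Because the getters return a reference rather than a copy, a caller that needs the data beyond the lifetime of the memory object would have to copy it explicitly.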
-
void set_input_ids(TensorObject &&input_ids)
Set the input ids object.
- Parameters
input_ids – The input ids tensor to set
- Throws
std::length_error – If the number of rows in input_ids does not match count.
-
void set_input_mask(TensorObject &&input_mask)
Set the input mask object.
- Parameters
input_mask – The input mask tensor to set
- Throws
std::length_error – If the number of rows in input_mask does not match count.
-
void set_seq_ids(TensorObject &&seq_ids)
Set the seq ids object.
- Parameters
seq_ids – The seq ids tensor to set
- Throws
std::length_error – If the number of rows in seq_ids does not match count.
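The setters take their argument by rvalue reference and reject tensors whose row count differs from count. A minimal sketch of that pattern follows, assuming a hypothetical `CountCheckedMemory` class and a `RowTensor` alias that models a tensor as a list of rows; it illustrates the documented behavior and is not the Morpheus source.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in: a tensor as a list of rows; not the Morpheus TensorObject.
using RowTensor = std::vector<std::vector<int>>;

class CountCheckedMemory
{
  public:
    explicit CountCheckedMemory(std::size_t count) : m_count(count) {}

    // Take ownership of the tensor only if its row count matches count.
    void set_input_ids(RowTensor&& input_ids)
    {
        if (input_ids.size() != m_count)
        {
            throw std::length_error("input_ids has " + std::to_string(input_ids.size()) +
                                    " rows, expected " + std::to_string(m_count));
        }
        m_input_ids = std::move(input_ids);
    }

    const RowTensor& get_input_ids() const
    {
        return m_input_ids;
    }

  private:
    std::size_t m_count;
    RowTensor m_input_ids;
};

// Returns true if the tensor was accepted, false if its row count was rejected.
bool try_set_input_ids(CountCheckedMemory& memory, RowTensor tensor)
{
    try
    {
        memory.set_input_ids(std::move(tensor));
        return true;
    }
    catch (const std::length_error&)
    {
        return false;
    }
}
```

In this sketch the check runs before the move, so a rejected tensor leaves the previously stored value untouched.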