Struct MatxUtil
Defined in File matx_util.hpp
Struct Documentation
struct MatxUtil
Public Static Functions
- static std::shared_ptr<rmm::device_buffer> cast(const DevMemInfo &input, TypeId output_type)
Converts the data in a device buffer from its current element type to output_type.
- Parameters:
input – Device memory holding the values to convert.
output_type – Element type of the returned buffer.
- Returns:
std::shared_ptr<rmm::device_buffer>
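To illustrate what an element-type conversion does, here is a minimal host-side sketch using std::vector instead of rmm::device_buffer. It is not the MatxUtil implementation: the helper name cast_sketch is hypothetical, and the source and target types are template parameters rather than the runtime TypeId the documented function takes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical host-side analogue of cast(): every element of `input` is
// converted to OutputT with static_cast, producing a new buffer of the same
// length. Types are compile-time parameters here, unlike the documented API.
template <typename OutputT, typename InputT>
std::vector<OutputT> cast_sketch(const std::vector<InputT>& input)
{
    std::vector<OutputT> output(input.size());
    std::transform(input.begin(), input.end(), output.begin(),
                   [](InputT v) { return static_cast<OutputT>(v); });
    return output;
}
```

Note that a floating-point to integer static_cast truncates toward zero, so values such as 3.9 become 3.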
- TensorIndex row_count, TensorIndex fea_len, TypeId output_type, std::shared_ptr<MemoryDescriptor> md, TensorIndex start_idx = 0)
Builds an Nx3 segment ID matrix.
- Parameters:
row_count – Number of rows (N) in the returned matrix.
fea_len –
output_type – Element type of the returned buffer.
start_idx –
- Returns:
std::shared_ptr<rmm::device_buffer>
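The column layout of the Nx3 segment ID matrix is not spelled out above. As an illustration only, the host-side sketch below assumes that row i holds {start_idx + i, 0, fea_len - 1}, that is, the row's index followed by the first and last feature positions it covers, stored row-major. The helper name make_seq_ids_sketch and this layout are assumptions, not the library's documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side sketch of an Nx3 segment ID matrix (row-major).
// ASSUMED layout per row i: {start_idx + i, 0, fea_len - 1}.
std::vector<std::int64_t> make_seq_ids_sketch(std::int64_t row_count,
                                              std::int64_t fea_len,
                                              std::int64_t start_idx = 0)
{
    assert(row_count >= 0 && fea_len > 0);
    std::vector<std::int64_t> ids(static_cast<std::size_t>(row_count) * 3);
    for (std::int64_t i = 0; i < row_count; ++i)
    {
        ids[i * 3 + 0] = start_idx + i;  // row index, shifted by start_idx
        ids[i * 3 + 1] = 0;              // first feature position
        ids[i * 3 + 2] = fea_len - 1;    // last feature position
    }
    return ids;
}
```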
- static void offset_seq_ids(const DevMemInfo &input, TensorIndex offset)
Adds a constant offset to a seg_ids tensor in place.
- Parameters:
input – The seg_ids tensor to modify.
offset – Constant value added to the tensor.
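Because offset_seq_ids returns void, it modifies the memory described by input directly. The sketch below shows that in-place pattern on a host-side, row-major Nx3 matrix. Which columns receive the offset is not stated in the source; this sketch assumes only column 0 (the row index) is shifted. The helper name offset_seq_ids_sketch is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical in-place sketch: shifts column 0 of a row-major Nx3 matrix by
// `offset`. ASSUMPTION: columns 1 and 2 are left unchanged.
void offset_seq_ids_sketch(std::vector<std::int64_t>& seq_ids, std::int64_t offset)
{
    assert(seq_ids.size() % 3 == 0);
    for (std::size_t row = 0; row < seq_ids.size() / 3; ++row)
    {
        seq_ids[row * 3] += offset;
    }
}
```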
- static std::shared_ptr<rmm::device_buffer> logits(const DevMemInfo &input)
Calculates logits on the values in a device buffer and returns the results in a new buffer.
- Parameters:
input – Device memory holding the values to transform.
- Returns:
std::shared_ptr<rmm::device_buffer>
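The source does not define the exact transform. The host-side sketch below assumes the element-wise logistic sigmoid, 1 / (1 + exp(-x)), which maps raw model outputs (logits) to probabilities. If the function instead computes the logit function log(p / (1 - p)), the loop body changes accordingly. The helper name logits_sketch is hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical host-side sketch. ASSUMPTION: the transform is the element-wise
// logistic sigmoid, 1 / (1 + exp(-x)).
std::vector<double> logits_sketch(const std::vector<double>& input)
{
    std::vector<double> output;
    output.reserve(input.size());
    for (double x : input)
    {
        output.push_back(1.0 / (1.0 + std::exp(-x)));
    }
    return output;
}
```

Under this assumption, an input of 0 maps to 0.5, and large positive or negative inputs approach 1 and 0 respectively.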
- static std::shared_ptr<rmm::device_buffer> transpose(const DevMemInfo &input)
Transposes the input tensor.
- Parameters:
input – The tensor to transpose.
- Returns:
std::shared_ptr<rmm::device_buffer>
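For a 2-D tensor, a transpose moves element (r, c) to (c, r). The host-side sketch below shows the index arithmetic, assuming a row-major rows x cols input and a row-major cols x rows output. The helper name transpose_sketch and the 2-D, row-major layout are assumptions for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch: transposes a row-major rows x cols matrix
// into a row-major cols x rows matrix.
template <typename T>
std::vector<T> transpose_sketch(const std::vector<T>& in, std::size_t rows, std::size_t cols)
{
    assert(in.size() == rows * cols);
    std::vector<T> out(in.size());
    for (std::size_t r = 0; r < rows; ++r)
    {
        for (std::size_t c = 0; c < cols; ++c)
        {
            out[c * rows + r] = in[r * cols + c];  // element (r, c) moves to (c, r)
        }
    }
    return out;
}
```

For example, the 2x3 matrix [[1, 2, 3], [4, 5, 6]] becomes the 3x2 matrix [[1, 4], [2, 5], [3, 6]].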
- static std::shared_ptr<rmm::device_buffer> threshold(const DevMemInfo &input, double thresh_val, bool by_row)
Returns a boolean array in which element [i,j] is true where x[i,j] >= thresh_val. When by_row is true, an Nx1 array is returned instead, in which row i is true if any value in row i of the input is greater than or equal to thresh_val.
- Parameters:
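input – The tensor to threshold.
thresh_val – Value each element is compared against (>=).
by_row – If true, return one boolean per row (Nx1); otherwise return one boolean per element.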
- Returns:
std::shared_ptr<rmm::device_buffer>
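The two modes described for threshold can be sketched on the host as follows, assuming a row-major rows x cols input. The helper name threshold_sketch is hypothetical, and std::vector<bool> stands in for the returned device buffer.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side sketch of the documented semantics.
// by_row == false: rows x cols result, true where x[i,j] >= thresh_val.
// by_row == true:  rows x 1 result, true where any x[i,j] in row i >= thresh_val.
std::vector<bool> threshold_sketch(const std::vector<double>& x, std::size_t rows,
                                   std::size_t cols, double thresh_val, bool by_row)
{
    assert(x.size() == rows * cols);
    std::vector<bool> out(by_row ? rows : rows * cols, false);
    for (std::size_t i = 0; i < rows; ++i)
    {
        for (std::size_t j = 0; j < cols; ++j)
        {
            const bool hit = x[i * cols + j] >= thresh_val;
            if (by_row)
            {
                if (hit)
                {
                    out[i] = true;  // one qualifying value marks the whole row
                }
            }
            else
            {
                out[i * cols + j] = hit;
            }
        }
    }
    return out;
}
```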
- static std::shared_ptr<rmm::device_buffer> reduce_max(const DevMemInfo &input, const ShapeType &seq_ids, TensorIndex seq_id_offset, const ShapeType &output_shape)
Returns a buffer with shape output_shape containing the max value from the values in input, mapped according to seq_ids. For example, given a hypothetical input of:
input = [5, 2, 8, 9, 8, 2, 1]
seq_ids = [0, 0, 0, 1, 2, 3, 3]
it will return: [8, 9, 8, 2]
- Parameters:
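input – The values to reduce.
seq_ids – Output position for each value in input.
seq_id_offset – Offset applied to the values in seq_ids.
output_shape – Shape of the returned buffer.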
- Returns:
std::shared_ptr<rmm::device_buffer>
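The 1-D example for reduce_max can be reproduced with a host-side grouped maximum. This sketch assumes that seq_id_offset is subtracted from each id to obtain its output slot and that every slot receives at least one value. The helper name reduce_max_sketch is hypothetical, and std::vector<std::int64_t> stands in for ShapeType.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical host-side sketch of a grouped max reduction.
// ASSUMPTION: slot = seq_ids[i] - seq_id_offset, and every slot in
// [0, output_size) receives at least one value.
std::vector<double> reduce_max_sketch(const std::vector<double>& input,
                                      const std::vector<std::int64_t>& seq_ids,
                                      std::int64_t seq_id_offset,
                                      std::size_t output_size)
{
    assert(input.size() == seq_ids.size());
    std::vector<double> output(output_size, std::numeric_limits<double>::lowest());
    for (std::size_t i = 0; i < input.size(); ++i)
    {
        const std::int64_t slot = seq_ids[i] - seq_id_offset;
        assert(slot >= 0 && static_cast<std::size_t>(slot) < output_size);
        const auto s = static_cast<std::size_t>(slot);
        output[s] = std::max(output[s], input[i]);  // keep the running max per slot
    }
    return output;
}
```

With input [5, 2, 8, 9, 8, 2, 1], seq_ids [0, 0, 0, 1, 2, 3, 3], an offset of 0 and four output slots, this yields [8, 9, 8, 2], matching the example above.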